diff options
-rw-r--r-- | CHANGELOG.md | 4 | ||||
-rw-r--r-- | javascript/aspectRatioOverlay.js | 33 | ||||
-rw-r--r-- | javascript/contextMenus.js | 12 | ||||
-rw-r--r-- | javascript/edit-attention.js | 8 | ||||
-rw-r--r-- | javascript/extensions.js | 12 | ||||
-rw-r--r-- | javascript/extraNetworks.js | 18 | ||||
-rw-r--r-- | javascript/generationParams.js | 2 | ||||
-rw-r--r-- | javascript/hints.js | 4 | ||||
-rw-r--r-- | javascript/hires_fix.js | 16 | ||||
-rw-r--r-- | javascript/imageMaskFix.js | 13 | ||||
-rw-r--r-- | javascript/imageParams.js | 1 | ||||
-rw-r--r-- | javascript/imageviewer.js | 12 | ||||
-rw-r--r-- | javascript/localization.js | 10 | ||||
-rw-r--r-- | javascript/notification.js | 6 | ||||
-rw-r--r-- | javascript/progressbar.js | 7 | ||||
-rw-r--r-- | javascript/ui.js | 36 | ||||
-rw-r--r-- | modules/safe.py | 6 | ||||
-rw-r--r-- | modules/sd_models.py | 54 | ||||
-rw-r--r-- | modules/shared.py | 31 | ||||
-rw-r--r-- | modules/ui.py | 10 | ||||
-rw-r--r-- | webui.py | 16 |
21 files changed, 173 insertions, 138 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b8a3611..8d2f96e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 1.1.1
+### Bug Fixes:
+ * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
+
## 1.1.0
### Features:
* switch to torch 2.0.0 (except for AMD GPUs)
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index a8278cca..5160081d 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){ var viewportOffset = targetElement.getBoundingClientRect();
- viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
+ var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
- scaledx = targetElement.naturalWidth*viewportscale
- scaledy = targetElement.naturalHeight*viewportscale
+ var scaledx = targetElement.naturalWidth*viewportscale
+ var scaledy = targetElement.naturalHeight*viewportscale
- cleintRectTop = (viewportOffset.top+window.scrollY)
- cleintRectLeft = (viewportOffset.left+window.scrollX)
- cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
- cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
+ var cleintRectTop = (viewportOffset.top+window.scrollY)
+ var cleintRectLeft = (viewportOffset.left+window.scrollX)
+ var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
+ var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
- viewRectTop = cleintRectCentreY-(scaledy/2)
- viewRectLeft = cleintRectCentreX-(scaledx/2)
- arRectWidth = scaledx
- arRectHeight = scaledy
+ var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
+ var arscaledx = currentWidth*arscale
+ var arscaledy = currentHeight*arscale
- arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
- arscaledx = currentWidth*arscale
- arscaledy = currentHeight*arscale
-
- arRectTop = cleintRectCentreY-(arscaledy/2)
- arRectLeft = cleintRectCentreX-(arscaledx/2)
- arRectWidth = arscaledx
- arRectHeight = arscaledy
+ var arRectTop = cleintRectCentreY-(arscaledy/2)
+ var arRectLeft = cleintRectCentreX-(arscaledx/2)
+ var arRectWidth = arscaledx
+ var arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px';
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 9468c107..42f301ab 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -4,7 +4,7 @@ contextMenuInit = function(){ let menuSpecs = new Map();
const uid = function(){
- return Date.now().toString(36) + Math.random().toString(36).substr(2);
+ return Date.now().toString(36) + Math.random().toString(36).substring(2);
}
function showContextMenu(event,element,menuEntries){
@@ -16,8 +16,7 @@ contextMenuInit = function(){ oldMenu.remove()
}
- let tabButton = uiCurrentTab
- let baseStyle = window.getComputedStyle(tabButton)
+ let baseStyle = window.getComputedStyle(uiCurrentTab)
const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu"
@@ -36,7 +35,7 @@ contextMenuInit = function(){ menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name']
- contextMenuEntry.addEventListener("click", function(e) {
+ contextMenuEntry.addEventListener("click", function() {
entry['func']();
})
contextMenuList.append(contextMenuEntry);
@@ -63,7 +62,7 @@ contextMenuInit = function(){ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
- currentItems = menuSpecs.get(targetElementSelector)
+ var currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){
currentItems = []
@@ -79,7 +78,7 @@ contextMenuInit = function(){ }
function removeContextMenuOption(uid){
- menuSpecs.forEach(function(v,k) {
+ menuSpecs.forEach(function(v) {
let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){
@@ -112,7 +111,6 @@ contextMenuInit = function(){ if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v)
e.preventDefault()
- return
}
})
});
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 588c7b77..d2c2f190 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -69,8 +69,8 @@ function keyupEditAttention(event){ event.preventDefault();
- closeCharacter = ')'
- delta = opts.keyedit_precision_attention
+ var closeCharacter = ')'
+ var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>'
@@ -91,8 +91,8 @@ function keyupEditAttention(event){ selectionEnd += 1;
}
- end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
diff --git a/javascript/extensions.js b/javascript/extensions.js index 3c2f995a..2a2d2f8e 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,14 +1,14 @@ -function extensions_apply(_, _, disable_all){
+function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = []
var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
+ disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked)
- update.push(x.name.substr(7))
+ update.push(x.name.substring(7))
})
restart_reload()
@@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){ return [JSON.stringify(disable), JSON.stringify(update), disable_all]
}
-function extensions_check(_, _){
+function extensions_check(){
var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
+ disable.push(x.name.substring(7))
})
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
@@ -41,7 +41,7 @@ function install_extension_from_index(button, url){ button.disabled = "disabled"
button.value = "Installing..."
- textarea = gradioApp().querySelector('#extension_to_install textarea')
+ var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
updateInput(textarea)
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 25322138..c8f6b386 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -10,11 +10,11 @@ function setupExtraNetworksForTab(tabname){ tabs.appendChild(search)
tabs.appendChild(refresh)
- search.addEventListener("input", function(evt){
- searchTerm = search.value.toLowerCase()
+ search.addEventListener("input", function(){
+ var searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
+ var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
})
});
@@ -55,7 +55,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){ var partToSearch = m[1]
var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
+ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
m = found.match(re_extranet);
if(m[1] == partToSearch){
replaced = true;
@@ -96,9 +96,9 @@ function saveCardPreview(event, tabname, filename){ }
function extraNetworksSearchButton(tabs_id, event){
- searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
- button = event.target
- text = button.classList.contains("search-all") ? "" : button.textContent.trim()
+ var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
+ var button = event.target
+ var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text
updateInput(searchTextarea)
@@ -133,7 +133,7 @@ function popup(contents){ }
function extraNetworksShowMetadata(text){
- elem = document.createElement('pre')
+ var elem = document.createElement('pre')
elem.classList.add('popup-metadata');
elem.textContent = text;
@@ -165,7 +165,7 @@ function requestGet(url, data, handler, errorHandler){ }
function extraNetworksRequestMetadata(event, extraPage, cardName){
- showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
+ var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){
diff --git a/javascript/generationParams.js b/javascript/generationParams.js index 1266a266..ef64ee2e 100644 --- a/javascript/generationParams.js +++ b/javascript/generationParams.js @@ -23,7 +23,7 @@ let modalObserver = new MutationObserver(function(mutations) { }); function attachGalleryListeners(tab_name) { - gallery = gradioApp().querySelector('#'+tab_name+'_gallery') + var gallery = gradioApp().querySelector('#'+tab_name+'_gallery') gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); gallery?.addEventListener('keydown', (e) => { if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow diff --git a/javascript/hints.js b/javascript/hints.js index e7d17d36..8d1967a7 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -118,7 +118,9 @@ titles = { onUiUpdate(function(){ gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){ - tooltip = titles[span.textContent]; + if (span.title) return; // already has a title + + let tooltip = titles[span.textContent]; if(!tooltip){ tooltip = titles[span.value]; diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js index 0629475f..48196be4 100644 --- a/javascript/hires_fix.js +++ b/javascript/hires_fix.js @@ -1,16 +1,12 @@ -function setInactive(elem, inactive){
- if(inactive){
- elem.classList.add('inactive')
- } else{
- elem.classList.remove('inactive')
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
+ function setInactive(elem, inactive){
+ elem.classList.toggle('inactive', !!inactive)
}
-}
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
- hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
- hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
+ var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
+ var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
+ var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js index 9fe7a603..a612705d 100644 --- a/javascript/imageMaskFix.js +++ b/javascript/imageMaskFix.js @@ -2,11 +2,10 @@ * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 * @see https://github.com/gradio-app/gradio/issues/1721 */ -window.addEventListener( 'resize', () => imageMaskResize()); function imageMaskResize() { const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); if ( ! canvases.length ) { - canvases_fixed = false; + canvases_fixed = false; // TODO: this is unused..? window.removeEventListener( 'resize', imageMaskResize ); return; } @@ -15,7 +14,7 @@ function imageMaskResize() { const previewImage = wrapper.previousElementSibling; if ( ! previewImage.complete ) { - previewImage.addEventListener( 'load', () => imageMaskResize()); + previewImage.addEventListener( 'load', imageMaskResize); return; } @@ -24,7 +23,6 @@ function imageMaskResize() { const nw = previewImage.naturalWidth; const nh = previewImage.naturalHeight; const portrait = nh > nw; - const factor = portrait; const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); @@ -40,6 +38,7 @@ function imageMaskResize() { c.style.maxHeight = '100%'; c.style.objectFit = 'contain'; }); - } - - onUiUpdate(() => imageMaskResize()); +} + +onUiUpdate(imageMaskResize); +window.addEventListener( 'resize', imageMaskResize); diff --git a/javascript/imageParams.js b/javascript/imageParams.js index 67404a89..64aee93b 100644 --- a/javascript/imageParams.js +++ b/javascript/imageParams.js @@ -1,7 +1,6 @@ window.onload = (function(){ window.addEventListener('drop', e => { const target = e.composedPath()[0]; - const idx = selected_gallery_index(); if (target.placeholder.indexOf("Prompt") == -1) return; let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 3deffa9b..32066ab8 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -57,7 +57,7 @@ function modalImageSwitch(offset) { }) if (result != -1) { - nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] + var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] nextButton.click() const modalImage = gradioApp().getElementById("modalImage"); const modal = gradioApp().getElementById("lightboxModal"); @@ -144,15 +144,11 @@ function setupImageForLightbox(e) { } function modalZoomSet(modalImage, enable) { - if (enable) { - modalImage.classList.add('modalImageFullscreen'); - } else { - modalImage.classList.remove('modalImageFullscreen'); - } + if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable); } function modalZoomToggle(event) { - modalImage = gradioApp().getElementById("modalImage"); + var modalImage = gradioApp().getElementById("modalImage"); modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) event.stopPropagation() } @@ -179,7 +175,7 @@ function galleryImageHandler(e) { } onUiUpdate(function() { - fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') + var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') if (fullImg_preview != null) { fullImg_preview.forEach(setupImageForLightbox); } diff --git a/javascript/localization.js b/javascript/localization.js index 1a5a1dbb..e1ffa271 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -35,11 +35,11 @@ function canBeTranslated(node, text){ if(! text) return false;
if(! node.parentElement) return false;
- parentType = node.parentElement.nodeName
+ var parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){
- pnode = node
+ var pnode = node
for(var level=0; level<4; level++){
pnode = pnode.parentElement
if(! pnode) break;
@@ -69,7 +69,7 @@ function getTranslation(text){ }
function processTextNode(node){
- text = node.textContent.trim()
+ var text = node.textContent.trim()
if(! canBeTranslated(node, text)) return
@@ -105,7 +105,7 @@ function processNode(node){ }
function dumpTranslations(){
- dumped = {}
+ var dumped = {}
if (localization.rtl) {
dumped.rtl = true
}
@@ -151,7 +151,7 @@ document.addEventListener("DOMContentLoaded", function() { })
function download_localization() {
- text = JSON.stringify(dumpTranslations(), null, 4)
+ var text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
diff --git a/javascript/notification.js b/javascript/notification.js index 8ddd4c5d..83fce1f8 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -2,15 +2,15 @@ let lastHeadImg = null; -notificationButton = null +let notificationButton = null; onUiUpdate(function(){ if(notificationButton == null){ notificationButton = gradioApp().getElementById('request_notifications') if(notificationButton != null){ - notificationButton.addEventListener('click', function (evt) { - Notification.requestPermission(); + notificationButton.addEventListener('click', () => { + void Notification.requestPermission(); },true); } } diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 23bbf298..8d2c3492 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,16 +1,15 @@ // code related to showing and updating progressbar shown as the image is being made -function rememberGallerySelection(id_gallery){ +function rememberGallerySelection(){ } -function getGallerySelectedIndex(id_gallery){ +function getGallerySelectedIndex(){ } function request(url, data, handler, errorHandler){ var xhr = new XMLHttpRequest(); - var url = url; xhr.open("POST", url, true); xhr.setRequestHeader("Content-Type", "application/json"); xhr.onreadystatechange = function () { @@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre divProgress.style.width = rect.width + "px"; } - progressText = "" + let progressText = "" divInner.style.width = ((res.progress || 0) * 100.0) + '%' divInner.style.background = res.progress ? "" : "transparent" diff --git a/javascript/ui.js b/javascript/ui.js index bfe31525..b63b84b2 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -1,7 +1,7 @@ // various functions for interaction with ui.py not large enough to warrant putting them in separate files function set_theme(theme){ - gradioURL = window.location.href + var gradioURL = window.location.href if (!gradioURL.includes('?__theme=')) { window.location.replace(gradioURL + '?__theme=' + theme); } @@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){ return [gallery[0]]; } - index = selected_gallery_index() + var index = selected_gallery_index() if (index < 0 || index >= gallery.length){ // Use the first image in the gallery as the default @@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){ } function args_to_array(args){ - res = [] + var res = [] for(var i=0;i<args.length;i++){ res.push(args[i]) } @@ -138,7 +138,7 @@ function get_img2img_tab_index() { } function create_submit_args(args){ - res = [] + var res = [] for(var i=0;i<args.length;i++){ res.push(args[i]) } @@ -160,7 +160,7 @@ function showSubmitButtons(tabname, show){ } function showRestoreProgressButton(tabname, show){ - button = gradioApp().getElementById(tabname + "_restore_progress") + var button = gradioApp().getElementById(tabname + "_restore_progress") if(! button) return button.style.display = show ? "flex" : "none" @@ -207,8 +207,9 @@ function submit_img2img(){ return res } -function restoreProgressTxt2img(x){ +function restoreProgressTxt2img(){ showRestoreProgressButton("txt2img", false) + var id = localStorage.getItem("txt2img_task_id") id = localStorage.getItem("txt2img_task_id") @@ -220,10 +221,11 @@ function restoreProgressTxt2img(x){ return id } -function restoreProgressImg2img(x){ - showRestoreProgressButton("img2img", false) - id = localStorage.getItem("img2img_task_id") +function restoreProgressImg2img(){ + showRestoreProgressButton("img2img", false) + + var id = localStorage.getItem("img2img_task_id") if(id) { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ @@ -252,7 +254,7 @@ function modelmerger(){ function ask_for_style_name(_, prompt_text, negative_prompt_text) { - name_ = prompt('Style name:') + var name_ = prompt('Style name:') return [name_, prompt_text, negative_prompt_text] } @@ -287,11 +289,11 @@ function recalculate_prompts_img2img(){ } -opts = {} +var opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; - json_elem = gradioApp().getElementById('settings_json') + var json_elem = gradioApp().getElementById('settings_json') if(json_elem == null) return; var textarea = json_elem.querySelector('textarea') @@ -340,8 +342,8 @@ onUiUpdate(function(){ registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button') registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button') - show_all_pages = gradioApp().getElementById('settings_show_all_pages') - settings_tabs = gradioApp().querySelector('#settings div') + var show_all_pages = gradioApp().getElementById('settings_show_all_pages') + var settings_tabs = gradioApp().querySelector('#settings div') if(show_all_pages && settings_tabs){ settings_tabs.appendChild(show_all_pages) show_all_pages.onclick = function(){ @@ -353,9 +355,9 @@ onUiUpdate(function(){ }) onOptionsChanged(function(){ - elem = gradioApp().getElementById('sd_checkpoint_hash') - sd_checkpoint_hash = opts.sd_checkpoint_hash || "" - shorthash = sd_checkpoint_hash.substr(0,10) + var elem = gradioApp().getElementById('sd_checkpoint_hash') + var sd_checkpoint_hash = opts.sd_checkpoint_hash || "" + var shorthash = sd_checkpoint_hash.substring(0,10) if(elem && elem.textContent != shorthash){ elem.textContent = shorthash diff --git a/modules/safe.py b/modules/safe.py index dadf319c..e6c2f2c0 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -24,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return TypedStorage(_internal=True)
+
+ try:
+ return TypedStorage(_internal=True)
+ except TypeError:
+ return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
def find_class(self, module, name):
if self.extra_handler is not None:
diff --git a/modules/sd_models.py b/modules/sd_models.py index 4f7613a1..59adc7cc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,8 @@ import collections import os.path
import sys
import gc
+import threading
+
import torch
import re
import safetensors.torch
@@ -404,13 +406,39 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+ def __init__(self):
+ self.sd_model = None
+ self.lock = threading.Lock()
+
+ def get_sd_model(self):
+ if self.sd_model is None:
+ with self.lock:
+ try:
+ load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load", file=sys.stderr)
+ self.sd_model = None
+
+ return self.sd_model
+
+ def set_sd_model(self, v):
+ self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
- if shared.sd_model:
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
gc.collect()
devices.torch_gc()
@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ timer.record("hijack")
sd_model.eval()
- shared.sd_model = sd_model
+ model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_info = info or select_checkpoint()
if not sd_model:
- sd_model = shared.sd_model
+ sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None): del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
- return shared.sd_model
+ return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None): return sd_model
+
def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
timer = Timer()
- if shared.sd_model:
-
- # shared.sd_model.cond_stage_model.to(devices.cpu)
- # shared.sd_model.first_stage_model.to(devices.cpu)
- shared.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
+ if model_data.sd_model:
+ model_data.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
diff --git a/modules/shared.py b/modules/shared.py index 6a2b3c2b..151bab9e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,6 +16,7 @@ import modules.styles import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None
@@ -600,13 +601,37 @@ class Options: return value
-
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+
+class Shared(sys.modules[__name__].__class__):
+ """
+ this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
+ at program startup.
+ """
+
+ sd_model_val = None
+
+ @property
+ def sd_model(self):
+ import modules.sd_models
+
+ return modules.sd_models.model_data.get_sd_model()
+
+ @sd_model.setter
+ def sd_model(self, value):
+ import modules.sd_models
+
+ modules.sd_models.model_data.set_sd_model(value)
+
+
|