diff options
30 files changed, 473 insertions, 217 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 0ad49f4e..bc11cc6e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -1,7 +1,6 @@ import os import gc import time -import warnings import numpy as np import torch @@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap from modules import shared, sd_hijack -warnings.filterwarnings("ignore", category=UserWarning) - cached_ldsr_model: torch.nn.Module = None diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index b947cbec..ccc8344a 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -69,7 +69,6 @@ addEventListener('keydown', (event) => { target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
}
- // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
- // internal Svelte data binding remains in sync.
- target.dispatchEvent(new Event("input", { bubbles: true }));
+
+ updateInput(target)
});
diff --git a/javascript/extensions.js b/javascript/extensions.js index 59179ca6..ac6e35b9 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -29,7 +29,7 @@ function install_extension_from_index(button, url){ textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
- textarea.dispatchEvent(new Event("input", { bubbles: true }))
+ updateInput(textarea)
gradioApp().querySelector('#install_extension_button').click()
}
diff --git a/javascript/hints.js b/javascript/hints.js index fa5e5ae8..e746e20d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "No interpolation": "Result = A", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 1f29ad7b..aac2ee82 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -148,7 +148,15 @@ function showGalleryImage() { if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' e.style.userSelect='none' - e.addEventListener('mousedown', function (evt) { + + var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 + + // For Firefox, listening on click first switched to next image then shows the lightbox. + // If you know how to fix this without switching to mousedown event, please. + // For other browsers the event is click to make it possiblr to drag picture. + var event = isFirefox ? 'mousedown' : 'click' + + e.addEventListener(event, function (evt) { if(!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) evt.preventDefault() diff --git a/javascript/localization.js b/javascript/localization.js index f92d2d24..1a5a1dbb 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -10,10 +10,8 @@ ignore_ids_for_localization={ modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
- txt2img_style_index: 'OPTION',
- txt2img_style2_index: 'OPTION',
- img2img_style_index: 'OPTION',
- img2img_style2_index: 'OPTION',
+ txt2img_styles: 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
diff --git a/javascript/progressbar.js b/javascript/progressbar.js index da6709bc..18c771a2 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -106,6 +106,19 @@ function formatTime(secs){ } } +function setTitle(progress){ + var title = 'Stable Diffusion' + + if(opts.show_progress_in_title && progress){ + title = '[' + progress.trim() + '] ' + title; + } + + if(document.title != title){ + document.title = title; + } +} + + function randomId(){ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" } @@ -117,7 +130,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var dateStart = new Date() var wasEverActive = false var parentProgressbar = progressbarContainer.parentNode - var parentGallery = gallery.parentNode + var parentGallery = gallery ? gallery.parentNode : null var divProgress = document.createElement('div') divProgress.className='progressDiv' @@ -128,13 +141,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre divProgress.appendChild(divInner) parentProgressbar.insertBefore(divProgress, progressbarContainer) - var livePreview = document.createElement('div') - livePreview.className='livePreview' - parentGallery.insertBefore(livePreview, gallery) + if(parentGallery){ + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + } var removeProgressBar = function(){ + setTitle("") parentProgressbar.removeChild(divProgress) - parentGallery.removeChild(livePreview) + if(parentGallery) parentGallery.removeChild(livePreview) atEnd() } @@ -154,6 +170,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre progressText = "" divInner.style.width = ((res.progress || 0) * 100.0) + '%' + divInner.style.background = res.progress ? "" : "transparent" if(res.progress > 0){ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' @@ -161,8 +178,13 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre if(res.eta){ progressText += " ETA: " + formatTime(res.eta) - } else if(res.textinfo){ - progressText += " " + res.textinfo + } + + + setTitle(progressText) + + if(res.textinfo && res.textinfo.indexOf("\n") == -1){ + progressText = res.textinfo + " " + progressText } divInner.textContent = progressText @@ -182,8 +204,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre } - if(res.live_preview){ - + if(res.live_preview && gallery){ var rect = gallery.getBoundingClientRect() if(rect.width){ livePreview.style.width = rect.width + "px" diff --git a/javascript/ui.js b/javascript/ui.js index ecf97cb3..37788a3e 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -109,6 +109,13 @@ function get_extras_tab_index(){ return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] } +function get_img2img_tab_index() { + let res = args_to_array(arguments) + res.splice(-2) + res[0] = get_tab_index('mode_img2img') + return res +} + function create_submit_args(args){ res = [] for(var i=0;i<args.length;i++){ @@ -165,6 +172,15 @@ function submit_img2img(){ return res } +function modelmerger(){ + var id = randomId() + requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) + + var res = create_submit_args(arguments) + res[0] = id + return res +} + function ask_for_style_name(_, prompt_text, negative_prompt_text) { name_ = prompt('Style name:') @@ -278,3 +294,11 @@ function restart_reload(){ return [] } + +// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits +// will only visible on web page and not sent to python. +function updateInput(target){ + let e = new Event("input", { bubbles: true }) + Object.defineProperty(e, "target", {value: target}) + target.dispatchEvent(e); +} diff --git a/modules/extras.py b/modules/extras.py index 22668fcd..d03f976e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple from functools import partial
from dataclasses import dataclass
-from modules import processing, shared, images, devices, sd_models, sd_samplers
+from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,25 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -287,49 +303,89 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
- result_is_inpainting_model = False
+ def filename_weighed_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_differnece():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
+ "Weighted sum": (filename_weighed_sum, None, weighted_sum),
+ "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
}
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
- if theta_func1 and not tertiary_model_info:
- shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
- shared.state.end()
- return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ if not primary_model_name:
+ return fail("Failed: Merging requires a primary model.")
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if theta_func2 and not secondary_model_name:
+ return fail("Failed: Merging requires a secondary model.")
+
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+
+ if theta_func1 and not tertiary_model_name:
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+
+ result_is_inpainting_model = False
+
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
+ shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
+ shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
del theta_2
+ shared.state.nextjob()
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
- chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
+ shared.state.textinfo = 'Merging A and B'
+ shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge:
continue
@@ -337,7 +393,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -352,36 +407,39 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam else:
theta_0[key] = theta_func2(a, b, multiplier)
- if save_as_half:
- theta_0[key] = theta_0[key].half()
+ theta_0[key] = to_half(theta_0[key], save_as_half)
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
+ shared.state.sampling_step += 1
- if key in chckpoint_dict_skip_on_merge:
- continue
-
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
del theta_1
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
+
+ del vae_dict
+
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.nextjob()
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -394,8 +452,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c963fc40..74e78582 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
diff --git a/modules/images.py b/modules/images.py index c3a5fc8b..3b1c5f34 100644 --- a/modules/images.py +++ b/modules/images.py @@ -605,8 +605,9 @@ def read_info_from_image(image): except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
diff --git a/modules/img2img.py b/modules/img2img.py index f4a03c57..2168c8e2 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -101,7 +101,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/processing.py b/modules/processing.py index 9c3673de..a3e9f709 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -538,10 +538,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None:
p.scripts.process(p)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
infotexts = []
output_images = []
@@ -572,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
@@ -857,7 +857,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
diff --git a/modules/progress.py b/modules/progress.py index 3327b883..c69ecf3d 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -67,10 +67,13 @@ def progressapi(req: ProgressRequest): progress = 0
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ job_count, job_no = shared.state.job_count, shared.state.job_no
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
+
+ if job_count > 0:
+ progress += job_no / job_count
+ if sampling_steps > 0 and job_count > 0:
+ progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 3ac0b97a..47f70251 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler): return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler): def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b0d95af..870eba88 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,12 +69,6 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|