diff options
author | unknown <mcgpapu@gmail.com> | 2022-12-12 15:12:26 +0000 |
---|---|---|
committer | unknown <mcgpapu@gmail.com> | 2022-12-12 15:12:26 +0000 |
commit | d6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6 (patch) | |
tree | 991a5b9c6c5bdd15bb9ff18bb68bd7c0df87bbfc | |
parent | 4005cd66e08d262b289d8d4a31fd425f260bcd11 (diff) | |
parent | 685f9631b56ff8bd43bce24ff5ce0f9a0e9af490 (diff) | |
download | stable-diffusion-webui-gfx803-d6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6.tar.gz stable-diffusion-webui-gfx803-d6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6.tar.bz2 stable-diffusion-webui-gfx803-d6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6.zip |
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 49 | ||||
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 1 | ||||
-rw-r--r-- | extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js | 2 | ||||
-rw-r--r-- | modules/processing.py | 17 | ||||
-rw-r--r-- | modules/safety.py | 42 | ||||
-rw-r--r-- | modules/scripts.py | 20 | ||||
-rw-r--r-- | modules/shared.py | 1 | ||||
-rw-r--r-- | modules/ui_extensions.py | 13 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | requirements_versions.txt | 1 | ||||
-rw-r--r-- | script.js | 7 |
11 files changed, 80 insertions, 74 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index a87d1ef9..8b048ae0 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -11,25 +11,41 @@ from omegaconf import OmegaConf from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap +from modules import shared, sd_hijack warnings.filterwarnings("ignore", category=UserWarning) +cached_ldsr_model: torch.nn.Module = None + # Create LDSR Class class LDSR: def load_model_from_config(self, half_attention): - print(f"Loading model from {self.modelPath}") - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] - config = OmegaConf.load(self.yamlPath) - config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" - model = instantiate_from_config(config.model) - model.load_state_dict(sd, strict=False) - model.cuda() - if half_attention: - model = model.half() - - model.eval() + global cached_ldsr_model + + if shared.opts.ldsr_cached and cached_ldsr_model is not None: + print(f"Loading model from cache") + model: torch.nn.Module = cached_ldsr_model + else: + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" + model: torch.nn.Module = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model = model.to(shared.device) + if half_attention: + model = model.half() + if shared.cmd_opts.opt_channelslast: + model = model.to(memory_format=torch.channels_last) + + sd_hijack.model_hijack.hijack(model) # apply optimization + model.eval() + + if shared.opts.ldsr_cached: + cached_ldsr_model = model + return {"model": model} def __init__(self, model_path, yaml_path): @@ -94,7 +110,8 @@ class LDSR: down_sample_method = 'Lanczos' gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() im_og = image width_og, height_og = im_og.size @@ -131,7 +148,9 @@ class LDSR: del model gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() + return a @@ -146,7 +165,7 @@ def get_cond(selected_path): c = rearrange(c, '1 c h w -> 1 h w c') c = 2. * c - 1. - c = c.to(torch.device("cuda")) + c = c.to(shared.device) example["LR_image"] = c example["image"] = c_up diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 5c96037d..29d5f94e 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -59,6 +59,7 @@ def on_ui_settings(): import gradio as gr shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 3f3bebcd..eccfb0f9 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -88,7 +88,7 @@ function checkBrackets(evt) { if(counterElt.title != '') { counterElt.style = 'color: #FF5555;'; } else { - counterElt.style = 'color: #000;'; + counterElt.style = ''; } } diff --git a/modules/processing.py b/modules/processing.py index 81400d14..24c537d1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,13 +13,15 @@ from skimage import exposure from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
+import modules.sd_models as sd_models
+import modules.sd_vae as sd_vae
import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -454,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try:
for k, v in p.override_settings.items():
- setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
- if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
+ if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p)
@@ -463,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@@ -571,9 +577,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc()
- if opts.filter_nsfw:
- import modules.safety as safety
- x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+ if p.scripts is not None:
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
diff --git a/modules/safety.py b/modules/safety.py deleted file mode 100644 index cff4b278..00000000 --- a/modules/safety.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
- global safety_feature_extractor, safety_checker
-
- if safety_feature_extractor is None:
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
- return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
- x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
- x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- return x
diff --git a/modules/scripts.py b/modules/scripts.py index b934d881..23ca195d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -88,6 +88,17 @@ class Script: pass
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -347,6 +358,15 @@ class ScriptRunner: print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_batch(self, p, images, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
diff --git a/modules/shared.py b/modules/shared.py index 44922c91..272267c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -367,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b487ac25..1434f25f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -206,12 +206,13 @@ def refresh_available_extensions_from_data(hide_tags): if url is None:
continue
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
+ extension_tags = extension_tags + ["installed"] if existing else extension_tags
+
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
hidden += 1
continue
- existing = installed_extension_urls.get(normalize_git_url(url), None)
-
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
@@ -222,7 +223,11 @@ def refresh_available_extensions_from_data(hide_tags): <td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
- """
+
+ """
+
+ for tag in [x for x in extension_tags if x not in tags]:
+ tags[tag] = tag
code += """
</tbody>
@@ -272,7 +277,7 @@ def create_ui(): install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row():
- hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
install_result = gr.HTML()
available_extensions_table = gr.HTML()
diff --git a/requirements.txt b/requirements.txt index 05818aa6..678acb4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ accelerate
basicsr
-diffusers
fairscale==0.4.4
fonts
font-roboto
diff --git a/requirements_versions.txt b/requirements_versions.txt index 035fa82f..185cd066 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,4 @@ transformers==4.19.2
-diffusers==0.3.0
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
@@ -1,5 +1,6 @@ -function gradioApp(){ - return document.getElementsByTagName('gradio-app')[0].shadowRoot; +function gradioApp() { + const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot + return !!gradioShadowRoot ? gradioShadowRoot : document; } function get_uiCurrentTab() { @@ -82,4 +83,4 @@ function uiElementIsVisible(el) { } } return isVisible; -}
\ No newline at end of file +} |