diff options
-rw-r--r-- | modules/api/api.py | 26 | ||||
-rw-r--r-- | modules/extensions.py | 7 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 4 | ||||
-rw-r--r-- | modules/images.py | 2 | ||||
-rw-r--r-- | modules/img2img.py | 4 | ||||
-rw-r--r-- | modules/processing.py | 29 | ||||
-rw-r--r-- | modules/sd_models.py | 10 | ||||
-rw-r--r-- | modules/sd_samplers.py | 13 | ||||
-rw-r--r-- | modules/sd_vae.py | 36 | ||||
-rw-r--r-- | modules/shared.py | 3 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 4 | ||||
-rw-r--r-- | modules/txt2img.py | 3 | ||||
-rw-r--r-- | modules/ui.py | 2 | ||||
-rw-r--r-- | modules/ui_extensions.py | 4 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | requirements_versions.txt | 1 | ||||
-rw-r--r-- | scripts/img2imgalt.py | 4 | ||||
-rw-r--r-- | scripts/xy_grid.py | 12 | ||||
-rw-r--r-- | webui-user.bat | 1 | ||||
-rw-r--r-- | webui-user.sh | 3 | ||||
-rw-r--r-- | webui.bat | 12 | ||||
-rw-r--r-- | webui.py | 1 | ||||
-rwxr-xr-x | webui.sh | 16 |
23 files changed, 113 insertions, 85 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..0eccccbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,9 +6,9 @@ from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -25,8 +25,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -82,14 +86,9 @@ class Api: self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -109,12 +108,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -123,10 +116,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -272,7 +264,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/extensions.py b/modules/extensions.py index 94ce479a..db9c4200 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -65,9 +65,12 @@ class Extension: self.can_update = False
self.status = "latest"
- def pull(self):
+ def fetch_and_reset_hard(self):
repo = git.Repo(self.path)
- repo.remotes.origin.pull()
+ # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+ # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+ repo.git.fetch('--all')
+ repo.git.reset('--hard', 'origin')
def list_extensions():
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712..fbb87dd1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared
+from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_index = preview_sampler_index
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/images.py b/modules/images.py index ae705cbd..26d5b7a9 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
- 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
+ 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
diff --git a/modules/img2img.py b/modules/img2img.py index be9f3653..9fc5b693 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -6,7 +6,7 @@ import traceback import numpy as np
from PIL import Image, ImageOps, ImageChops
-from modules import devices
+from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_index=sampler_index,
+ sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
diff --git a/modules/processing.py b/modules/processing.py index 2fc9fe13..fb30aa81 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -2,6 +2,7 @@ import json import math
import os
import sys
+import warnings
import torch
import numpy as np
@@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays): return image
-def get_correct_sampler(p):
- if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
- return sd_samplers.samplers
- elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
- return sd_samplers.samplers_for_img2img
- elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
- return sd_samplers.samplers
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+ if sampler_index is not None:
+ warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
+
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -91,7 +88,7 @@ class StableDiffusionProcessing(): self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_index: int = sampler_index
+ self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
@@ -210,8 +207,7 @@ class Processed: self.info = info
self.width = p.width
self.height = p.height
- self.sampler_index = p.sampler_index
- self.sampler = sd_samplers.samplers[p.sampler_index].name
+ self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@@ -256,8 +252,7 @@ class Processed: "subseed_strength": self.subseed_strength,
"width": self.width,
"height": self.height,
- "sampler_index": self.sampler_index,
- "sampler": self.sampler,
+ "sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
"steps": self.steps,
"batch_size": self.batch_size,
@@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = {
"Steps": p.steps,
- "Sampler": get_correct_sampler(p)[p.sampler_index].name,
+ "Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@@ -646,7 +641,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -707,7 +702,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -744,7 +739,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
if self.image_mask is not None:
diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..c59151e0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): cache_enabled = shared.opts.sd_checkpoint_cache > 0
- if cache_enabled:
- sd_vae.restore_base_vae(model)
-
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
-
if cache_enabled and checkpoint_info in checkpoints_loaded:
# use checkpoint cache
- vae_name = sd_vae.get_filename(vae_file) if vae_file else None
- vae_message = f" with {vae_name} VAE" if vae_name else ""
- print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+ print(f"Loading weights [{sd_model_hash}] from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
@@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 783992d2..4fe67854 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -46,16 +46,23 @@ all_samplers = [ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
+all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
-def create_sampler_with_index(list_of_configs, index, model):
- config = list_of_configs[index]
+def create_sampler(name, model):
+ if name is not None:
+ config = all_samplers_map.get(name, None)
+ else:
+ config = all_samplers[0]
+
+ assert config is not None, f'bad sampler name: {name}'
+
sampler = config.constructor(model)
sampler.config = config
-
+
return sampler
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..9c120975 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -83,47 +83,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): return vae_list -def resolve_vae(checkpoint_file, vae_file="auto"): +def get_vae_from_settings(vae_file="auto"): + # else, we load from settings, if not set to be default + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print(f"Selected VAE doesn't exist: {vae_file}") + return vae_file + + +def resolve_vae(checkpoint_file=None, vae_file="auto"): global first_load, vae_dict, vae_list # if vae_file argument is provided, it takes priority, but not saved if vae_file and vae_file not in default_vae_list: if not os.path.isfile(vae_file): + print(f"VAE provided as function argument doesn't exist: {vae_file}") vae_file = "auto" - print("VAE provided as function argument doesn't exist") # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported if first_load and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path shared.opts.data['sd_vae'] = get_filename(vae_file) else: - print("VAE provided as command line argument doesn't exist") - # else, we load from settings - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print("Selected VAE doesn't exist") + print(f"VAE provided as command line argument doesn't exist: {vae_file}") + # fallback to selector in settings, if vae selector not set to act as default fallback + if not shared.opts.sd_vae_as_default: + vae_file = get_vae_from_settings(vae_file) # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path - print("Using VAE provided as command line argument") + print(f"Using VAE provided as command line argument: {vae_file}") # if still not found, try look for ".vae.pt" beside model model_path = os.path.splitext(checkpoint_file)[0] if vae_file == "auto": vae_file_try = model_path + ".vae.pt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # if still not found, try look for ".vae.ckpt" beside model if vae_file == "auto": vae_file_try = model_path + ".vae.ckpt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None @@ -139,6 +146,7 @@ def load_vae(model, vae_file=None): # save_settings = False if vae_file: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..84567c8e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -334,7 +334,8 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 0aeb0459..5e4d8688 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,7 +10,7 @@ import csv from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models, images
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_index = preview_sampler_index
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/txt2img.py b/modules/txt2img.py index 8e4e8677..c8f81176 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,4 +1,5 @@ import modules.scripts
+from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
@@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_index=sampler_index,
+ sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
diff --git a/modules/ui.py b/modules/ui.py index 5dce7f3b..2d488741 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip
if do_make_zip:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6671cb60..030f011e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list): continue
try:
- ext.pull()
+ ext.fetch_and_reset_hard()
except Exception:
- print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+ print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
diff --git a/requirements.txt b/requirements.txt index 0fba0b23..762db4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +accelerate
basicsr
diffusers
fairscale==0.4.4
diff --git a/requirements_versions.txt b/requirements_versions.txt index f7059f20..662ca684 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,6 @@ transformers==4.19.2
diffusers==0.3.0
+accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.9
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 964b75c7..1229f61b 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -157,7 +157,7 @@ class Script(scripts.Script): def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override
if override_sampler:
- p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+ p.sampler_name = "Euler"
if override_prompt:
p.prompt = original_prompt
p.negative_prompt = original_negative_prompt
@@ -191,7 +191,7 @@ class Script(scripts.Script): combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
- sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
+ sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 417ed0d4..b0b9d84d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,9 +10,9 @@ import numpy as np import modules.scripts as scripts
import gradio as gr
-from modules import images
+from modules import images, sd_samplers
from modules.hypernetworks import hypernetwork
-from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
@@ -60,9 +60,9 @@ def apply_order(p, x, xs): p.prompt = prompt_tmp + p.prompt
-def build_samplers_dict(p):
+def build_samplers_dict():
samplers_dict = {}
- for i, sampler in enumerate(get_correct_sampler(p)):
+ for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases:
samplers_dict[alias.lower()] = i
@@ -70,7 +70,7 @@ def build_samplers_dict(p): def apply_sampler(p, x, xs):
- sampler_index = build_samplers_dict(p).get(x.lower(), None)
+ sampler_index = build_samplers_dict().get(x.lower(), None)
if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}")
@@ -78,7 +78,7 @@ def apply_sampler(p, x, xs): def confirm_samplers(p, xs):
- samplers_dict = build_samplers_dict(p)
+ samplers_dict = build_samplers_dict()
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
diff --git a/webui-user.bat b/webui-user.bat index e5a257be..0c120111 100644 --- a/webui-user.bat +++ b/webui-user.bat @@ -4,5 +4,6 @@ set PYTHON= set GIT= set VENV_DIR= set COMMANDLINE_ARGS= +set ACCELERATE= call webui.bat diff --git a/webui-user.sh b/webui-user.sh index 30646f5c..16e42759 100644 --- a/webui-user.sh +++ b/webui-user.sh @@ -40,4 +40,7 @@ export COMMANDLINE_ARGS="" #export CODEFORMER_COMMIT_HASH="" #export BLIP_COMMIT_HASH="" +# Uncomment to enable accelerated launch +#export ACCELERATE="True" + ########################################### @@ -28,15 +28,27 @@ goto :show_stdout_stderr :activate_venv
set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
+if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch
:skip_venv
+:accelerate
+echo "Checking for accelerate"
+set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
:launch
%PYTHON% launch.py %*
pause
exit /b
+:accelerate_launch
+echo "Accelerating"
+%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
+pause
+exit /b
+
:show_stdout_stderr
echo.
@@ -82,6 +82,7 @@ def initialize(): modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
@@ -134,7 +134,15 @@ else exit 1 fi -printf "\n%s\n" "${delimiter}" -printf "Launching launch.py..." -printf "\n%s\n" "${delimiter}" -"${python_cmd}" "${LAUNCH_SCRIPT}" "$@" +if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] +then + printf "\n%s\n" "${delimiter}" + printf "Accelerating launch.py..." + printf "\n%s\n" "${delimiter}" + accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" +else + printf "\n%s\n" "${delimiter}" + printf "Launching launch.py..." + printf "\n%s\n" "${delimiter}" |