diff options
author    | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-07 15:43:38 +0000
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-07 15:43:38 +0000
commit    | cabd4e3b3bf91e0cb5071398a8efddef495f6311 (patch)
tree      | 55daa888a7e03e2e204daf6729835b94277350a2 /modules
parent    | bb832d7725187f8a8ab44faa6ee1b38cb5f600aa (diff)
parent    | 804d9fb83d0c63ca3acd36378707ce47b8f12599 (diff)
download  | stable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.tar.gz
          | stable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.tar.bz2
          | stable-diffusion-webui-gfx803-cabd4e3b3bf91e0cb5071398a8efddef495f6311.zip
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 53
-rw-r--r-- | modules/api/models.py | 40
-rw-r--r-- | modules/extensions.py | 7
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 59
-rw-r--r-- | modules/hypernetworks/ui.py | 2
-rw-r--r-- | modules/ldsr_model_arch.py | 14
-rw-r--r-- | modules/localization.py | 6
-rw-r--r-- | modules/safe.py | 40
-rw-r--r-- | modules/script_callbacks.py | 69
-rw-r--r-- | modules/scripts.py | 1
-rw-r--r-- | modules/sd_samplers.py | 4
-rw-r--r-- | modules/shared.py | 16
-rw-r--r-- | modules/ui.py | 62
-rw-r--r-- | modules/ui_extensions.py | 55
-rw-r--r-- | modules/upscaler.py | 12
15 files changed, 326 insertions, 114 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index a49f3755..688469ad 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -10,6 +10,7 @@ from modules.api.models import *
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.sd_samplers import all_samplers
 from modules.extras import run_extras, run_pnginfo
+from PIL import PngImagePlugin
 from modules.sd_models import checkpoints_list
 from modules.realesrgan_model import get_realesrgan_models
 from typing import List
@@ -34,9 +35,21 @@ def setUpscalers(req: dict):
 
 def encode_pil_to_base64(image):
-    buffer = io.BytesIO()
-    image.save(buffer, format="png")
-    return base64.b64encode(buffer.getvalue())
+    with io.BytesIO() as output_bytes:
+
+        # Copy any text-only metadata
+        use_metadata = False
+        metadata = PngImagePlugin.PngInfo()
+        for key, value in image.info.items():
+            if isinstance(key, str) and isinstance(value, str):
+                metadata.add_text(key, value)
+                use_metadata = True
+
+        image.save(
+            output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
+        )
+        bytes_data = output_bytes.getvalue()
+    return base64.b64encode(bytes_data)
 
 
 class Api:
@@ -50,6 +63,7 @@ class Api:
         self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
         self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
         self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
         self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
         self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
         self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
@@ -201,11 +215,24 @@ class Api:
         return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
 
+    def interrogateapi(self, interrogatereq: InterrogateRequest):
+        image_b64 = interrogatereq.image
+        if image_b64 is None:
+            raise HTTPException(status_code=404, detail="Image not found")
+
+        img = self.__base64_to_image(image_b64)
+
+        # Override object param
+        with self.queue_lock:
+            processed = shared.interrogator.interrogate(img)
+
+        return InterrogateResponse(caption=processed)
+
     def interruptapi(self):
         shared.state.interrupt()
 
         return {}
-
+
     def get_config(self):
         options = {}
         for key in shared.opts.data.keys():
@@ -214,10 +241,14 @@ class Api:
                 options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
             else:
                 options.update({key: shared.opts.data.get(key, None)})
-
+
         return options
-
+
     def set_config(self, req: OptionsModel):
+        # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
+        # overwrite all options with default values.
+        raise RuntimeError('Setting options via API is not supported')
+
         reqDict = vars(req)
         for o in reqDict:
             setattr(shared.opts, o, reqDict[o])
@@ -233,13 +264,13 @@ class Api:
     def get_upscalers(self):
         upscalers = []
-
+
         for upscaler in shared.sd_upscalers:
             u = upscaler.scaler
             upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
+
         return upscalers
-
+
     def get_sd_models(self):
         return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
 
@@ -251,11 +282,11 @@ class Api:
     def get_realesrgan_models(self):
         return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
-
+
     def get_promp_styles(self):
         styleList = []
         for k in shared.prompt_styles.styles:
-            style = shared.prompt_styles.styles[k]
+            style = shared.prompt_styles.styles[k]
             styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
 
         return styleList
diff --git a/modules/api/models.py b/modules/api/models.py
index 2ae75f43..34dbfa16 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,11 +1,11 @@
 import inspect
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional, Union
+from typing import Any, Optional
 from typing_extensions import Literal
 from inflection import underscore
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 from modules.shared import sd_upscalers, opts, parser
-from typing import List
+from typing import Dict, List
 
 API_NOT_ALLOWED = [
     "self",
@@ -65,6 +65,7 @@ class PydanticModelGenerator:
 
         self._model_name = model_name
         self._class_data = merge_class_params(class_instance)
+
         self._model_def = [
             ModelDef(
                 field=underscore(k),
@@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
     state: dict = Field(title="State", description="The current state snapshot")
     current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
 
+class InterrogateRequest(BaseModel):
+    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class InterrogateResponse(BaseModel):
+    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
 fields = {}
 for key, value in opts.data.items():
     metadata = opts.data_labels.get(key)
@@ -185,22 +192,22 @@ _options = vars(parser)['_option_string_actions']
 for key in _options:
     if(_options[key].dest != 'help'):
         flag = _options[key]
-        _type = str
-        if(_options[key].default != None): _type = type(_options[key].default)
+        _type = str
+        if _options[key].default is not None: _type = type(_options[key].default)
         flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
 
 FlagsModel = create_model("Flags", **flags)
 
 class SamplerItem(BaseModel):
     name: str = Field(title="Name")
-    aliases: list[str] = Field(title="Aliases")
-    options: dict[str, str] = Field(title="Options")
+    aliases: List[str] = Field(title="Aliases")
+    options: Dict[str, str] = Field(title="Options")
 
 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
-    model_name: str | None = Field(title="Model Name")
-    model_path: str | None = Field(title="Path")
-    model_url: str | None = Field(title="URL")
+    model_name: Optional[str] = Field(title="Model Name")
+    model_path: Optional[str] = Field(title="Path")
+    model_url: Optional[str] = Field(title="URL")
 
 class SDModelItem(BaseModel):
     title: str = Field(title="Title")
@@ -211,23 +218,24 @@ class SDModelItem(BaseModel):
 
 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
-    path: str | None = Field(title="Path")
+    path: Optional[str] = Field(title="Path")
 
 class FaceRestorerItem(BaseModel):
     name: str = Field(title="Name")
-    cmd_dir: str | None = Field(title="Path")
+    cmd_dir: Optional[str] = Field(title="Path")
 
 class RealesrganItem(BaseModel):
     name: str = Field(title="Name")
-    path: str | None = Field(title="Path")
-    scale: int | None = Field(title="Scale")
+    path: Optional[str] = Field(title="Path")
+    scale: Optional[int] = Field(title="Scale")
 
 class PromptStyleItem(BaseModel):
     name: str = Field(title="Name")
-    prompt: str | None = Field(title="Prompt")
-    negative_prompt: str | None = Field(title="Negative Prompt")
+    prompt: Optional[str] = Field(title="Prompt")
+    negative_prompt: Optional[str] = Field(title="Negative Prompt")
 
 class ArtistItem(BaseModel):
     name: str = Field(title="Name")
     score: float = Field(title="Score")
-    category: str = Field(title="Category")
\ No newline at end of file
+    category: str = Field(title="Category")
+
diff --git a/modules/extensions.py b/modules/extensions.py
index 897af96e..8e0977fd 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -34,8 +34,11 @@ class Extension:
         if repo is None or repo.bare:
             self.remote = None
         else:
-            self.remote = next(repo.remote().urls, None)
-            self.status = 'unknown'
+            try:
+                self.remote = next(repo.remote().urls, None)
+                self.status = 'unknown'
+            except Exception:
+                self.remote = None
 
     def list_files(self, subdir, extension):
         from modules import scripts
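Note: the `/sdapi/v1/interrogate` route registered in `modules/api/api.py` above accepts an `InterrogateRequest` (a base64 `image` string) and returns an `InterrogateResponse` with a `caption`. A minimal client sketch; the host, port, and use of the `requests` package are assumptions, not part of this commit:

```python
import base64
import requests  # third-party HTTP client (assumed available)

# InterrogateRequest.image expects a plain base64 string of the image data.
with open("sample.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/interrogate",  # default local URL (assumed)
    json={"image": image_b64},
)
print(resp.json()["caption"])  # InterrogateResponse.caption
```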
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 02b624e1..3371b18e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -22,6 +22,8 @@ from collections import defaultdict, deque
from statistics import stdev, mean
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -142,6 +144,8 @@ class Hypernetwork:
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
@@ -163,6 +167,7 @@ class Hypernetwork:
def save(self, filename):
state_dict = {}
+ optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -178,8 +183,15 @@ class Hypernetwork:
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
-
+
+ if self.optimizer_name is not None:
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
torch.save(state_dict, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@@ -202,6 +214,18 @@ class Hypernetwork:
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+ print(f"Optimizer name is {self.optimizer_name}")
+ if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
+ if self.optimizer_state_dict:
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
@@ -219,11 +243,11 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name] = filename
+ res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@@ -358,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -410,8 +435,22 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
- # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+
+ # Here we use optimizer from saved HN, or we can specify as UI option.
+ if hypernetwork.optimizer_name in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ optimizer_name = hypernetwork.optimizer_name
+ else:
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
+
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
steps_without_grad = 0
@@ -479,7 +518,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -542,8 +585,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
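Note: with the `save()`/`load()` changes above, optimizer state now lives in a sidecar `<name>.pt.optim` file tied to the hypernetwork file through `sd_models.model_hash`. A sketch of inspecting such a pair outside the webui; the file name is a hypothetical example:

```python
import os
import torch

filename = "my_hypernetwork.pt"  # hypothetical example name
optim_path = filename + ".optim"

if os.path.exists(optim_path):
    # Keys written by Hypernetwork.save(): optimizer_name, hash, optimizer_state_dict
    optim = torch.load(optim_path, map_location="cpu")
    print("optimizer:", optim.get("optimizer_name", "AdamW"))
    print("tied to model hash:", optim.get("hash"))
    # load() only resumes from optimizer_state_dict when this hash matches
    # sd_models.model_hash(filename) for the .pt file itself.
```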
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index aad09ffc..c2d4b51c 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
-keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
index 14db5076..90e0a2f0 100644
--- a/modules/ldsr_model_arch.py
+++ b/modules/ldsr_model_arch.py
@@ -101,8 +101,8 @@ class LDSR:
         down_sample_rate = target_scale / 4
         wd = width_og * down_sample_rate
         hd = height_og * down_sample_rate
-        width_downsampled_pre = int(wd)
-        height_downsampled_pre = int(hd)
+        width_downsampled_pre = int(np.ceil(wd))
+        height_downsampled_pre = int(np.ceil(hd))
 
         if down_sample_rate != 1:
             print(
@@ -110,7 +110,12 @@ class LDSR:
             im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
         else:
             print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-        logs = self.run(model["model"], im_og, diffusion_steps, eta)
+
+        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+
+        logs = self.run(model["model"], im_padded, diffusion_steps, eta)
 
         sample = logs["sample"]
         sample = sample.detach().cpu()
@@ -120,6 +125,9 @@ class LDSR:
         sample = np.transpose(sample, (0, 2, 3, 1))
         a = Image.fromarray(sample[0])
 
+        # remove padding
+        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
         del model
         gc.collect()
         torch.cuda.empty_cache()
diff --git a/modules/localization.py b/modules/localization.py
index b1810cda..f6a6f2fb 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -3,6 +3,7 @@ import os
import sys
import traceback
+
localizations = {}
@@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
+ from modules import scripts
+ for file in scripts.list_scripts("localizations", ".json"):
+ fn, ext = os.path.splitext(file.filename)
+ localizations[fn] = file.path
+
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
diff --git a/modules/safe.py b/modules/safe.py
index 348a24fc..a9209e38 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,11 +23,18 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
- raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+ raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
raise Exception(f"bad file inside {filename}: {name}")
-def check_pt(filename):
+def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
@@ -78,6 +85,7 @@ def check_pt(filename):
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@@ -85,16 +93,42 @@ def check_pt(filename):
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
+ return load_with_extra(filename, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this function is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename)
+ check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index c28e220e..74dfb880 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,25 +46,23 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-callbacks_app_started = []
-callbacks_model_loaded = []
-callbacks_ui_tabs = []
-callbacks_ui_settings = []
-callbacks_before_image_saved = []
-callbacks_image_saved = []
-callbacks_cfg_denoiser = []
+callback_map = dict(
+ callbacks_app_started=[],
+ callbacks_model_loaded=[],
+ callbacks_ui_tabs=[],
+ callbacks_ui_settings=[],
+ callbacks_before_image_saved=[],
+ callbacks_image_saved=[],
+ callbacks_cfg_denoiser=[]
+)
def clear_callbacks():
- callbacks_model_loaded.clear()
- callbacks_ui_tabs.clear()
- callbacks_ui_settings.clear()
- callbacks_before_image_saved.clear()
- callbacks_image_saved.clear()
- callbacks_cfg_denoiser.clear()
+ for callback_list in callback_map.values():
+ callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in callbacks_app_started:
+ for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -72,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
- for c in callbacks_model_loaded:
+ for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -82,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
- for c in callbacks_ui_tabs:
+ for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -92,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
- for c in callbacks_ui_settings:
+ for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -100,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_before_image_saved:
+ for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -108,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -116,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in callbacks_cfg_denoiser:
+ for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -129,17 +127,33 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+
+def remove_current_script_callbacks():
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ if filename == 'unknown file':
+ return
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+ callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+ for callback_list in callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+ callback_list.remove(callback_to_remove)
+
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
- add_callback(callbacks_app_started, callback)
+ add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- add_callback(callbacks_model_loaded, callback)
+ add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -152,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- add_callback(callbacks_ui_tabs, callback)
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(callbacks_ui_settings, callback)
+ add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -166,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
- add_callback(callbacks_before_image_saved, callback)
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -174,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
- add_callback(callbacks_image_saved, callback)
+ add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -182,5 +196,4 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
- add_callback(callbacks_cfg_denoiser, callback)
-
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
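Note: with callbacks now held in `callback_map`, the two new helpers give extension scripts a clean way to deregister. A sketch of typical use from an extension; the callback body is illustrative:

```python
from modules import script_callbacks

def on_saved(params):  # params: ImageSaveParams
    print("image saved:", params.filename)

script_callbacks.on_image_saved(on_saved)

# Remove just this function from every callback list...
script_callbacks.remove_callbacks_for_function(on_saved)

# ...or remove everything the calling script file ever registered.
script_callbacks.remove_current_script_callbacks()
```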
diff --git a/modules/scripts.py b/modules/scripts.py
index 366c90d7..637b2329 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,7 +3,6 @@ import sys
import traceback
from collections import namedtuple
-import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index c7c414ef..783992d2 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -24,11 +24,15 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
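Note: each new DPM++ entry follows the existing tuple layout of (label, k-diffusion function name, aliases, options). Selecting one over the API might look like the sketch below; the `sampler_index` field name and URL are assumptions about the txt2img endpoint of this era, not part of this diff:

```python
import requests  # assumed available

payload = {
    "prompt": "a watercolor landscape",
    "steps": 20,
    "sampler_index": "DPM++ 2M Karras",  # label added above (field name assumed)
}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)  # URL assumed
```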
diff --git a/modules/shared.py b/modules/shared.py
index a9e28b9c..e8bacd3c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -86,6 +86,10 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
cmd_opts = parser.parse_args()
restricted_opts = {
@@ -147,9 +151,9 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
-
+
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@@ -198,7 +202,7 @@ class State:
return
if self.current_latent is None:
return
-
+
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
@@ -217,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-localization.list_localizations(cmd_opts.localizations_dir)
-
def realesrgan_models_names():
import modules.realesrgan_model
@@ -317,6 +319,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -406,7 +409,8 @@ class Options:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
- comp_args = opts.data_labels[key].component_args
+ info = opts.data_labels.get(key, None)
+ comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
diff --git a/modules/ui.py b/modules/ui.py
index db6d59d5..030d4144 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None):
gr.processing_utils.save_pil_to_file = save_pil_to_file
-def wrap_gradio_call(func, extra_outputs=None):
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
+ run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.job_count = 0
+
+ if not add_stats:
+ return tuple(res)
+
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
- if (elapsed_m > 0):
+ if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.job_count = 0
-
return tuple(res)
return f
@@ -1138,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[html, generation_info, html2],
)
- with gr.Blocks() as modelmerger_interface:
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
@@ -1158,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
- with gr.Blocks() as train_interface:
+ with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@@ -1423,15 +1426,14 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None:
if is_quicksettings:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
with gr.Row(variant="compact"):
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
-
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
@@ -1442,7 +1444,7 @@ def create_ui(wrap_gradio_gpu_call):
opts.reorder()
def run_settings(*args):
- changed = 0
+ changed = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
@@ -1452,18 +1454,20 @@ def create_ui(wrap_gradio_gpu_call):
continue
oldval = opts.data.get(key, None)
-
- setattr(opts, key, value)
-
+ try:
+ setattr(opts, key, value)
+ except RuntimeError:
+ continue
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
- changed += 1
-
- opts.save(shared.config_filename)
-
- return opts.dumpjson(), f'{changed} settings changed.'
+ changed.append(key)
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
@@ -1567,11 +1571,10 @@ def create_ui(wrap_gradio_gpu_call):
shared.state.need_restart = True
restart_gradio.click(
-
fn=request_restart,
+ _js='restart_reload',
inputs=[],
outputs=[],
- _js='restart_reload'
)
if column is not None:
@@ -1641,6 +1644,17 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[component, text_settings],
)
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [getattr(opts, key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index a81de9a7..02ab9643 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
-def install_extension_from_index(url):
+def install_extension_from_index(url, hide_tags):
ext_table, message = install_extension_from_url(None, url)
- return refresh_available_extensions_from_data(), ext_table, message
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ext_table, message
-def refresh_available_extensions(url):
+
+def refresh_available_extensions(url, hide_tags):
global available_extensions
import urllib.request
@@ -155,13 +157,25 @@ def refresh_available_extensions(url):
available_extensions = json.loads(text)
- return url, refresh_available_extensions_from_data(), ''
+ code, tags = refresh_available_extensions_from_data(hide_tags)
+
+ return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags):
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ''
-def refresh_available_extensions_from_data():
+
+def refresh_available_extensions_from_data(hide_tags):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+ tags = available_extensions.get("tags", {})
+ tags_to_hide = set(hide_tags)
+ hidden = 0
+
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
@@ -178,17 +192,24 @@ def refresh_available_extensions_from_data():
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
+ extension_tags = ext.get("tags", [])
if url is None:
continue
+ if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ hidden += 1
+ continue
+
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+ tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
+
code += f"""
<tr>
- <td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
+ <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
@@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
</table>
"""
- return code
+ if hidden > 0:
+ code += f"<p>Extension hidden: {hidden}</p>"
+
+ return code, list(tags)
def create_ui():
@@ -238,21 +262,30 @@ def create_ui():
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+ with gr.Row():
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
+
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
- inputs=[available_extensions_index],
- outputs=[available_extensions_index, available_extensions_table, install_result],
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ inputs=[available_extensions_index, hide_tags],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install],
+ inputs=[extension_to_install, hide_tags],
outputs=[available_extensions_table, extensions_table, install_result],
)
+ hide_tags.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags],
+ outputs=[available_extensions_table, install_result]
+ )
+
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 83fde7ca..c4e6e6bd 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -57,10 +57,18 @@ class Upscaler:
         self.scale = scale
         dest_w = img.width * scale
         dest_h = img.height * scale
+
         for i in range(3):
-            if img.width > dest_w and img.height > dest_h:
-                break
+            shape = (img.width, img.height)
+
             img = self.do_upscale(img, selected_model)
+
+            if shape == (img.width, img.height):
+                break
+
+            if img.width >= dest_w and img.height >= dest_h:
+                break
+
         if img.width != dest_w or img.height != dest_h:
             img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
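Note: the reworked loop records the image shape before each `do_upscale` call, stops when an upscaler makes no progress (avoiding a potential infinite loop) or the target size is reached, then LANCZOS-resizes any remainder. The control flow, restated as a standalone sketch:

```python
def upscale_until(img, dest_w, dest_h, do_upscale, max_rounds=3):
    """Standalone restatement of the Upscaler.upscale loop above."""
    for _ in range(max_rounds):
        shape = (img.width, img.height)
        img = do_upscale(img)
        if shape == (img.width, img.height):
            break  # no progress; bail out instead of looping forever
        if img.width >= dest_w and img.height >= dest_h:
            break  # target size reached
    return img
```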