aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py124
-rw-r--r--modules/api/processing.py106
-rw-r--r--modules/deepbooru.py140
-rw-r--r--modules/devices.py25
-rw-r--r--modules/extras.py116
-rw-r--r--modules/generation_parameters_copypaste.py34
-rw-r--r--modules/hypernetwork.py98
-rw-r--r--modules/hypernetworks/hypernetwork.py478
-rw-r--r--modules/hypernetworks/ui.py61
-rw-r--r--modules/images.py24
-rw-r--r--modules/images_history.py424
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/interrogate.py31
-rw-r--r--modules/localization.py31
-rw-r--r--modules/lowvram.py9
-rw-r--r--modules/ngrok.py17
-rw-r--r--modules/paths.py1
-rw-r--r--modules/processing.py259
-rw-r--r--modules/prompt_parser.py2
-rw-r--r--modules/safe.py117
-rw-r--r--modules/script_callbacks.py53
-rw-r--r--modules/scripts.py252
-rw-r--r--modules/sd_hijack.py127
-rw-r--r--modules/sd_hijack_inpainting.py331
-rw-r--r--modules/sd_hijack_optimizations.py156
-rw-r--r--modules/sd_models.py111
-rw-r--r--modules/sd_samplers.py204
-rw-r--r--modules/shared.py120
-rw-r--r--modules/styles.py4
-rw-r--r--modules/swinir_model.py35
-rw-r--r--modules/swinir_model_arch_v2.py1017
-rw-r--r--modules/textual_inversion/dataset.py82
-rw-r--r--modules/textual_inversion/image_embedding.py220
-rw-r--r--modules/textual_inversion/learn_schedule.py69
-rw-r--r--modules/textual_inversion/preprocess.py171
-rw-r--r--modules/textual_inversion/test_embedding.pngbin0 -> 489220 bytes
-rw-r--r--modules/textual_inversion/textual_inversion.py143
-rw-r--r--modules/textual_inversion/ui.py6
-rw-r--r--modules/txt2img.py12
-rw-r--r--modules/ui.py733
40 files changed, 5180 insertions, 766 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
new file mode 100644
index 00000000..3caa83a4
--- /dev/null
+++ b/modules/api/api.py
@@ -0,0 +1,124 @@
+from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.sd_samplers import all_samplers
+from modules.extras import run_pnginfo
+import modules.shared as shared
+import uvicorn
+from fastapi import Body, APIRouter, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field, Json
+import json
+import io
+import base64
+from PIL import Image
+
+sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
+class TextToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
+
+class Api:
+ def __init__(self, app, queue_lock):
+ self.router = APIRouter()
+ self.app = app
+ self.queue_lock = queue_lock
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+
+ def __base64_to_image(self, base64_string):
+ # if has a comma, deal with prefix
+ if "," in base64_string:
+ base64_string = base64_string.split(",")[1]
+ imgdata = base64.b64decode(base64_string)
+ # convert base64 to PIL image
+ return Image.open(io.BytesIO(imgdata))
+
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ sampler_index = sampler_to_index(txt2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+ populate = txt2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": sampler_index[0],
+ "do_not_save_samples": True,
+ "do_not_save_grid": True
+ }
+ )
+ p = StableDiffusionProcessingTxt2Img(**vars(populate))
+ # Override object param
+ with self.queue_lock:
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
+
+
+
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ sampler_index = sampler_to_index(img2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+ mask = img2imgreq.mask
+ if mask:
+ mask = self.__base64_to_image(mask)
+
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": sampler_index[0],
+ "do_not_save_samples": True,
+ "do_not_save_grid": True,
+ "mask": mask
+ }
+ )
+ p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+ imgs = []
+ for img in init_images:
+ img = self.__base64_to_image(img)
+ imgs = [img] * p.batch_size
+
+ p.init_images = imgs
+ # Override object param
+ with self.queue_lock:
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
+
+ def extrasapi(self):
+ raise NotImplementedError
+
+ def pnginfoapi(self):
+ raise NotImplementedError
+
+ def launch(self, server_name, port):
+ self.app.include_router(self.router)
+ uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/api/processing.py b/modules/api/processing.py
new file mode 100644
index 00000000..f551fa35
--- /dev/null
+++ b/modules/api/processing.py
@@ -0,0 +1,106 @@
+from array import array
+from inflection import underscore
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field, create_model
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+import inspect
+
+
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"]))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+).generate_model() \ No newline at end of file
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 7e3c0618..8bbc90a4 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -1,29 +1,122 @@
import os.path
from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import get_context
+import multiprocessing
+import time
+import re
+re_special = re.compile(r'([\\()])')
-def _load_tf_and_return_tags(pil_image, threshold):
+def get_deepbooru_tags(pil_image):
+ """
+ This method is for running only one image at a time for simple use. Used to the img2img interrogate.
+ """
+ from modules import shared # prevents circular reference
+
+ try:
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
+ return get_tags_from_process(pil_image)
+ finally:
+ release_process()
+
+
+OPT_INCLUDE_RANKS = "include_ranks"
+def create_deepbooru_opts():
+ from modules import shared
+
+ return {
+ "use_spaces": shared.opts.deepbooru_use_spaces,
+ "use_escape": shared.opts.deepbooru_escape,
+ "alpha_sort": shared.opts.deepbooru_sort_alpha,
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
+ }
+
+
+def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
+ model, tags = get_deepbooru_tags_model()
+ while True: # while process is running, keep monitoring queue for new image
+ pil_image = queue.get()
+ if pil_image == "QUIT":
+ break
+ else:
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
+
+
+def create_deepbooru_process(threshold, deepbooru_opts):
+ """
+ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
+ to be processed in a row without reloading the model or creating a new process. To return the data, a shared
+ dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
+ to the dictionary and the method adding the image to the queue should wait for this value to be updated with
+ the tags.
+ """
+ from modules import shared # prevents circular reference
+ context = multiprocessing.get_context("spawn")
+ shared.deepbooru_process_manager = context.Manager()
+ shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
+ shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+ shared.deepbooru_process.start()
+
+
+def get_tags_from_process(image):
+ from modules import shared
+
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ caption = shared.deepbooru_process_return["value"]
+ shared.deepbooru_process_return["value"] = -1
+
+ return caption
+
+
+def release_process():
+ """
+ Stops the deepbooru process to return used memory
+ """
+ from modules import shared # prevents circular reference
+ shared.deepbooru_process_queue.put("QUIT")
+ shared.deepbooru_process.join()
+ shared.deepbooru_process_queue = None
+ shared.deepbooru_process = None
+ shared.deepbooru_process_return = None
+ shared.deepbooru_process_manager = None
+
+def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
-
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
- load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
- model_path)
+ load_file_from_url(
+ r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+ model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
- model_path, compile_model=True
+ model_path, compile_model=False
)
+ return model, tags
+
+
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
+ import deepdanbooru as dd
+ import tensorflow as tf
+ import numpy as np
+
+ alpha_sort = deepbooru_opts['alpha_sort']
+ use_spaces = deepbooru_opts['use_spaces']
+ use_escape = deepbooru_opts['use_escape']
+ include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2]
height = model.input_shape[1]
@@ -46,28 +139,35 @@ def _load_tf_and_return_tags(pil_image, threshold):
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
- result_tags_out = []
+
+ unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
- result_tags_out.append(tag)
+ unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
- print('\n'.join(sorted(result_tags_print, reverse=True)))
-
- return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
+ # sort tags
+ result_tags_out = []
+ sort_ndx = 0
+ if alpha_sort:
+ sort_ndx = 1
+ # sort by reverse by likelihood and normal for alpha, and format tag text as requested
+ unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
+ for weight, tag in unsorted_tags_in_theshold:
+ tag_outformat = tag
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
-def subprocess_init_no_cuda():
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+ result_tags_out.append(tag_outformat)
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
-def get_deepbooru_tags(pil_image, threshold=0.5):
- context = get_context('spawn')
- with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
- f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
- ret = f.result() # will rethrow any exceptions
- return ret \ No newline at end of file
+ return ', '.join(result_tags_out)
diff --git a/modules/devices.py b/modules/devices.py
index 0158b11f..dc1f3cdd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")
@@ -34,8 +45,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
+dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,9 +71,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
-def autocast():
+def autocast(disable=False):
from modules import shared
+ if disable:
+ return contextlib.nullcontext()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/extras.py b/modules/extras.py
index 39dd3806..22c5a1c1 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,3 +1,4 @@
+import math
import os
import numpy as np
@@ -19,26 +20,43 @@ import gradio as gr
cached_images = {}
-def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc()
imageArr = []
# Also keep track of original file names
imageNameArr = []
-
+ outputs = []
+
if extras_mode == 1:
#convert file to pillow image
for img in image_folder:
- image = Image.fromarray(np.array(Image.open(img)))
+ image = Image.open(img)
imageArr.append(image)
imageNameArr.append(os.path.splitext(img.orig_name)[0])
+ elif extras_mode == 2:
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+
+ if input_dir == '':
+ return outputs, "Please select an input directory.", ''
+ image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
+ for img in image_list:
+ try:
+ image = Image.open(img)
+ except Exception:
+ continue
+ imageArr.append(image)
+ imageNameArr.append(img)
else:
imageArr.append(image)
imageNameArr.append(None)
- outpath = opts.outdir_samples or opts.outdir_extras_samples
+ if extras_mode == 2 and output_dir != '':
+ outpath = output_dir
+ else:
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
- outputs = []
+
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
@@ -67,25 +85,35 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
+ if resize_mode == 1:
+ upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
+ crop_info = " (crop)" if upscaling_crop else ""
+ info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+
if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
+ resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
+ c = cropped
cached_images[key] = c
return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility)
@@ -93,16 +121,21 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
+
+ if opts.use_original_name_batch and image_name != None:
+ basename = os.path.splitext(os.path.basename(image_name))[0]
+ else:
+ basename = ''
- images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
- forced_filename=image_name if opts.use_original_name_batch else None)
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
image.info["extras"] = info
- outputs.append(image)
+ if extras_mode != 2 or show_extras_results :
+ outputs.append(image)
devices.torch_gc()
@@ -149,48 +182,63 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
- # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
- # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def sigmoid(theta0, theta1, alpha):
- alpha = alpha * alpha * (3 - (2 * alpha))
- return theta0 + ((theta1 - theta0) * alpha)
+ def get_difference(theta1, theta2):
+ return theta1 - theta2
- # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def inv_sigmoid(theta0, theta1, alpha):
- import math
- alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
- return theta0 + ((theta1 - theta0) * alpha)
+ def add_difference(theta0, theta1_2_diff, alpha):
+ return theta0 + (alpha * theta1_2_diff)
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+ teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
- theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+ if teritary_model_info is not None:
+ print(f"Loading {teritary_model_info.filename}...")
+ teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+ theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+ else:
+ teritary_model = None
+ theta_2 = None
+
theta_funcs = {
- "Weighted Sum": weighted_sum,
- "Sigmoid": sigmoid,
- "Inverse Sigmoid": inv_sigmoid,
+ "Weighted sum": (None, weighted_sum),
+ "Add difference": (get_difference, add_difference),
}
- theta_func = theta_funcs[interp_method]
+ theta_func1, theta_func2 = theta_funcs[interp_method]
print(f"Merging...")
+
+ if theta_func1:
+ for key in tqdm.tqdm(theta_1.keys()):
+ if 'model' in key:
+ if key in theta_2:
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+ theta_1[key] = theta_func1(theta_1[key], t2)
+ else:
+ theta_1[key] = torch.zeros_like(theta_1[key])
+ del theta_2, teritary_model
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+
+ theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
+
if save_as_half:
theta_0[key] = theta_0[key].half()
-
+
+ # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
@@ -199,7 +247,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+ filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename)
@@ -209,4 +257,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models()
print(f"Checkpoint saved.")
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ac1ba7f4..f73647da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,13 +1,25 @@
+import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -42,11 +54,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- if len(prompt) > 0:
- res["Prompt"] = prompt
-
- if len(negative_prompt) > 0:
- res["Negative prompt"] = negative_prompt
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
m = re_imagesize.match(v)
@@ -61,6 +70,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
@@ -77,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
- val = valtype(v)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
deleted file mode 100644
index 498bc9d8..00000000
--- a/modules/hypernetwork.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- self.load_state_dict(state_dict, strict=True)
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, filename):
- self.filename = filename
- self.name = os.path.splitext(os.path.basename(filename))[0]
- self.layers = {}
-
- state_dict = torch.load(filename, map_location='cpu')
- for size, sd in state_dict.items():
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 00000000..98a7b62e
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,478 @@
+import csv
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+
+import modules.textual_inversion.dataset
+import torch
+import tqdm
+from einops import rearrange, repeat
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
+from modules.textual_inversion import textual_inversion
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+
+from statistics import stdev, mean
+
+class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ super().__init__()
+
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func
+ if activation_func == "linear" or activation_func is None:
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ self.linear = torch.nn.Sequential(*linears)
+
+ if state_dict is not None:
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
+ else:
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
+
+ self.to(devices.device)
+
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
+
+ def forward(self, x):
+ return x + self.linear(x) * self.multiplier
+
+ def trainables(self):
+ layer_structure = []
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
+ return layer_structure
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+ self.layer_structure = layer_structure
+ self.activation_func = activation_func
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
+
+ for size in enable_sizes or []:
+ self.layers[size] = (
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ )
+
+ def weights(self):
+ res = []
+
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ res += layer.trainables()
+
+ return res
+
+ def save(self, filename):
+ state_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+ torch.save(state_dict, filename)
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.activation_func = state_dict.get('activation_func', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ )
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
+ try:
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
+
+ shared.loaded_hypernetwork = None
+
+
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+
+def log_statistics(loss_info:dict, key, value):
+ if key not in loss_info:
+ loss_info[key] = [value]
+ else:
+ loss_info[key].append(value)
+ if len(loss_info) > 1024:
+ loss_info.pop(0)
+
+
+def statistics(data):
+ total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
+ recent_data = data[-32:]
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(loss_info[key])
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+ from modules import images
+
+ assert hypernetwork_name, 'hypernetwork not selected'
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
+ size = len(ds.indexes)
+ loss_dict = {}
+ losses = torch.zeros((size,))
+ previous_mean_loss = 0
+ print("Mean loss of {} elements".format(size))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+
+ steps_without_grad = 0
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+ if len(loss_dict) > 0:
+ previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+ for entry in entries:
+ log_statistics(loss_dict, entry.filename, loss.item())
+
+ optimizer.zero_grad()
+ weights[0].grad = None
+ loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
+ optimizer.step()
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ raise RuntimeError("Loss diverged.")
+ pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
+
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{previous_mean_loss:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {previous_mean_loss:.7f}<br/>
+Step: {hypernetwork.step}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ report_statistics(loss_dict)
+ checkpoint = sd_models.select_checkpoint()
+
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(filename)
+
+ return hypernetwork, filename
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 00000000..2b472d87
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,61 @@
+import html
+import os
+import re
+
+import gradio as gr
+import modules.textual_inversion.preprocess
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
+from modules.hypernetworks import hypernetwork
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+ initial_hypernetwork = shared.loaded_hypernetwork
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/images.py b/modules/images.py
index c0a90676..b9589563 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,5 @@
import datetime
+import io
import math
import os
from collections import namedtuple
@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
rows = opts.n_rows
elif opts.n_rows == 0:
rows = batch_size
+ elif opts.grid_prevent_empty_spots:
+ rows = math.floor(math.sqrt(len(imgs)))
+ while len(imgs) % rows != 0:
+ rows -= 1
else:
rows = math.sqrt(len(imgs))
rows = round(rows)
@@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn = None
return fullfn, txt_fullfn
+
+
+def image_data(data):
+ try:
+ image = Image.open(io.BytesIO(data))
+ textinfo = image.text["parameters"]
+ return textinfo, None
+ except Exception:
+ pass
+
+ try:
+ text = data.decode('utf8')
+ assert len(text) < 10000
+ return text, None
+
+ except Exception:
+ pass
+
+ return '', None
diff --git a/modules/images_history.py b/modules/images_history.py
new file mode 100644
index 00000000..bc5cf11f
--- /dev/null
+++ b/modules/images_history.py
@@ -0,0 +1,424 @@
+import os
+import shutil
+import time
+import hashlib
+import gradio
+system_bak_path = "webui_log_and_bak"
+custom_tab_name = "custom fold"
+faverate_tab_name = "favorites"
+tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
+def is_valid_date(date):
+ try:
+ time.strptime(date, "%Y%m%d")
+ return True
+ except:
+ return False
+
+def reduplicative_file_move(src, dst):
+ def same_name_file(basename, path):
+ name, ext = os.path.splitext(basename)
+ f_list = os.listdir(path)
+ max_num = 0
+ for f in f_list:
+ if len(f) <= len(basename):
+ continue
+ f_ext = f[-len(ext):] if len(ext) > 0 else ""
+ if f[:len(name)] == name and f_ext == ext:
+ if f[len(name)] == "(" and f[-len(ext)-1] == ")":
+ number = f[len(name)+1:-len(ext)-1]
+ if number.isdigit():
+ if int(number) > max_num:
+ max_num = int(number)
+ return f"{name}({max_num + 1}){ext}"
+ name = os.path.basename(src)
+ save_name = os.path.join(dst, name)
+ if not os.path.exists(save_name):
+ shutil.move(src, dst)
+ else:
+ name = same_name_file(name, dst)
+ shutil.move(src, os.path.join(dst, name))
+
+def traverse_all_files(curr_path, image_list, all_type=False):
+ try:
+ f_list = os.listdir(curr_path)
+ except:
+ if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
+ image_list.append(curr_path)
+ return image_list
+ for file in f_list:
+ file = os.path.join(curr_path, file)
+ if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
+ pass
+ elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
+ image_list.append(file)
+ else:
+ image_list = traverse_all_files(file, image_list)
+ return image_list
+
+def auto_sorting(dir_name):
+ bak_path = os.path.join(dir_name, system_bak_path)
+ if not os.path.exists(bak_path):
+ os.mkdir(bak_path)
+ log_file = None
+ files_list = []
+ f_list = os.listdir(dir_name)
+ for file in f_list:
+ if file == system_bak_path:
+ continue
+ file_path = os.path.join(dir_name, file)
+ if not is_valid_date(file):
+ if file[-10:].rfind(".") > 0:
+ files_list.append(file_path)
+ else:
+ files_list = traverse_all_files(file_path, files_list, all_type=True)
+
+ for file in files_list:
+ date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
+ file_path = os.path.dirname(file)
+ hash_path = hashlib.md5(file_path.encode()).hexdigest()
+ path = os.path.join(dir_name, date_str, hash_path)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ if log_file is None:
+ log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
+ log_file.write(f"{hash_path},{file_path}\n")
+ reduplicative_file_move(file, path)
+
+ date_list = []
+ f_list = os.listdir(dir_name)
+ for f in f_list:
+ if is_valid_date(f):
+ date_list.append(f)
+ elif f == system_bak_path:
+ continue
+ else:
+ try:
+ reduplicative_file_move(os.path.join(dir_name, f), bak_path)
+ except:
+ pass
+
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ if today not in date_list:
+ date_list.append(today)
+ return sorted(date_list, reverse=True)
+
+def archive_images(dir_name, date_to):
+ filenames = []
+ batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if batch_size <= 0:
+ batch_size = opts.images_history_num_per_page * 6
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ date_to = today if date_to is None or date_to == "" else date_to
+ date_to_bak = date_to
+ if False: #opts.images_history_reconstruct_directory:
+ date_list = auto_sorting(dir_name)
+ for date in date_list:
+ if date <= date_to:
+ path = os.path.join(dir_name, date)
+ if date == today and not os.path.exists(path):
+ continue
+ filenames = traverse_all_files(path, filenames)
+ if len(filenames) > batch_size:
+ break
+ filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
+ else:
+ filenames = traverse_all_files(dir_name, filenames)
+ total_num = len(filenames)
+ tmparray = [(os.path.getmtime(file), file) for file in filenames ]
+ date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
+ filenames = []
+ date_list = {date_to:None}
+ date = time.strftime("%Y%m%d",time.localtime(time.time()))
+ for t, f in tmparray:
+ date = time.strftime("%Y%m%d",time.localtime(t))
+ date_list[date] = None
+ if t <= date_stamp:
+ filenames.append((t, f ,date))
+ date_list = sorted(list(date_list.keys()), reverse=True)
+ sort_array = sorted(filenames, key=lambda x:-x[0])
+ if len(sort_array) > batch_size:
+ date = sort_array[batch_size][2]
+ filenames = [x[1] for x in sort_array]
+ else:
+ date = date_to if len(sort_array) == 0 else sort_array[-1][2]
+ filenames = [x[1] for x in sort_array]
+ filenames = [x[1] for x in sort_array if x[2]>= date]
+ num = len(filenames)
+ last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
+ date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
+ date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
+ load_info = "<div style='color:#999' align='center'>"
+ load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
+ load_info += "</div>"
+ _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
+ return (
+ date_to,
+ load_info,
+ filenames,
+ 1,
+ image_list,
+ "",
+ "",
+ visible_num,
+ last_date_from,
+ gradio.update(visible=total_num > num)
+ )
+
+def delete_image(delete_num, name, filenames, image_index, visible_num):
+ if name == "":
+ return filenames, delete_num
+ else:
+ delete_num = int(delete_num)
+ visible_num = int(visible_num)
+ image_index = int(image_index)
+ index = list(filenames).index(name)
+ i = 0
+ new_file_list = []
+ for name in filenames:
+ if i >= index and i < index + delete_num:
+ if os.path.exists(name):
+ if visible_num == image_index:
+ new_file_list.append(name)
+ i += 1
+ continue
+ print(f"Delete file {name}")
+ os.remove(name)
+ visible_num -= 1
+ txt_file = os.path.splitext(name)[0] + ".txt"
+ if os.path.exists(txt_file):
+ os.remove(txt_file)
+ else:
+ print(f"Not exists file {name}")
+ else:
+ new_file_list.append(name)
+ i += 1
+ return new_file_list, 1, visible_num
+
+def save_image(file_name):
+ if file_name is not None and os.path.exists(file_name):
+ shutil.copy(file_name, opts.outdir_save)
+
+def get_recent_images(page_index, step, filenames):
+ page_index = int(page_index)
+ num_of_imgs_per_page = int(opts.images_history_num_per_page)
+ max_page_index = len(filenames) // num_of_imgs_per_page + 1
+ page_index = max_page_index if page_index == -1 else page_index + step
+ page_index = 1 if page_index < 1 else page_index
+ page_index = max_page_index if page_index > max_page_index else page_index
+ idx_frm = (page_index - 1) * num_of_imgs_per_page
+ image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
+ length = len(filenames)
+ visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
+ visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
+ return page_index, image_list, "", "", visible_num
+
+def loac_batch_click(date_to):
+ if date_to is None:
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ else:
+ return None, []
+def forward_click(last_date_from, date_to_recorder):
+ if len(date_to_recorder) == 0:
+ return None, []
+ if last_date_from == date_to_recorder[-1]:
+ date_to_recorder = date_to_recorder[:-1]
+ if len(date_to_recorder) == 0:
+ return None, []
+ return date_to_recorder[-1], date_to_recorder[:-1]
+
+def backward_click(last_date_from, date_to_recorder):
+ if last_date_from is None or last_date_from == "":
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
+ date_to_recorder.append(last_date_from)
+ return last_date_from, date_to_recorder
+
+
+def first_page_click(page_index, filenames):
+ return get_recent_images(1, 0, filenames)
+
+def end_page_click(page_index, filenames):
+ return get_recent_images(-1, 0, filenames)
+
+def prev_page_click(page_index, filenames):
+ return get_recent_images(page_index, -1, filenames)
+
+def next_page_click(page_index, filenames):
+ return get_recent_images(page_index, 1, filenames)
+
+def page_index_change(page_index, filenames):
+ return get_recent_images(page_index, 0, filenames)
+
+def show_image_info(tabname_box, num, page_index, filenames):
+ file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
+ tm = "<div style='color:#999' align='right'>" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "</div>"
+ return file, tm, num, file
+
+def enable_page_buttons():
+ return gradio.update(visible=True)
+
+def change_dir(img_dir, date_to):
+ warning = None
+ try:
+ if os.path.exists(img_dir):
+ try:
+ f = os.listdir(img_dir)
+ except:
+ warning = f"'{img_dir} is not a directory"
+ else:
+ warning = "The directory is not exist"
+ except:
+ warning = "The format of the directory is incorrect"
+ if warning is None:
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
+ else:
+ return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
+
+def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
+ custom_dir = False
+ if tabname == "txt2img":
+ dir_name = opts.outdir_txt2img_samples
+ elif tabname == "img2img":
+ dir_name = opts.outdir_img2img_samples
+ elif tabname == "extras":
+ dir_name = opts.outdir_extras_samples
+ elif tabname == faverate_tab_name:
+ dir_name = opts.outdir_save
+ else:
+ custom_dir = True
+ dir_name = None
+
+ if not custom_dir:
+ d = dir_name.split("/")
+ dir_name = d[0]
+ for p in d[1:]:
+ dir_name = os.path.join(dir_name, p)
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
+
+ with gr.Column() as page_panel:
+ with gr.Row():
+ with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
+ load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
+ with gr.Column(scale=4):
+ with gr.Row():
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
+ with gr.Row():
+ with gr.Column(visible=False, scale=1) as batch_panel:
+ with gr.Row():
+ forward = gr.Button('Prev batch')
+ backward = gr.Button('Next batch')
+ with gr.Column(scale=3):
+ load_info = gr.HTML(visible=not custom_dir)
+ with gr.Row(visible=False) as warning:
+ warning_box = gr.Textbox("Message", interactive=False)
+
+ with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
+ with gr.Column(scale=2):
+ with gr.Row(visible=True) as turn_page_buttons:
+ #date_to = gr.Dropdown(label="Date to")
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+
+ history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
+ with gr.Row():
+ delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
+ delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
+
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
+ gr.HTML("<hr>")
+ img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+ img_file_time= gr.HTML()
+ with gr.Row():
+ if tabname != faverate_tab_name:
+ save_btn = gr.Button('Collect')
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+
+
+ # hiden items
+ with gr.Row(visible=False):
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ batch_date_to = gr.Textbox(label="Date to")
+ visible_img_num = gr.Number()
+ date_to_recorder = gr.State([])
+ last_date_from = gr.Textbox()
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ all_images_list = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+
+ img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
+
+ #change batch
+ change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
+
+ batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
+ batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
+ forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
+ backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
+
+
+ #delete
+ delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
+ delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
+ if tabname != faverate_tab_name:
+ save_btn.click(save_image, inputs=[img_file_name], outputs=None)
+
+ #turn page
+ gallery_inputs = [page_index, filenames]
+ gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
+ first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
+ renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
+
+ first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ # other funcitons
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
+ img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
+ hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+ switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
+ switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+
+
+
+def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
+ global opts;
+ opts = sys_opts
+ loads_files_num = int(opts.images_history_num_per_page)
+ num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if cmp_ops.browse_all_images:
+ tabs_list.append(custom_tab_name)
+ with gr.Blocks(analytics_enabled=False) as images_history:
+ with gr.Tabs() as tabs:
+ for tab in tabs_list:
+ with gr.Tab(tab):
+ with gr.Blocks(analytics_enabled=False) :
+ show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
+ gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
+ gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
+
+ return images_history
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..8d9f7cf9 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 635e266e..65b05d34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
+ running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
@@ -53,26 +55,30 @@ class InterrogateModels:
def load_clip_model(self):
import clip
- model, preprocess = clip.load(clip_model_name)
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu")
+ else:
+ model, preprocess = clip.load(clip_model_name)
+
model.eval()
- model = model.to(shared.device)
+ model = model.to(devices.device_interrogate)
return model, preprocess
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(shared.device)
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(shared.device)
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype
@@ -99,11 +105,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(shared.device)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -116,7 +122,7 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
@@ -140,7 +146,7 @@ class InterrogateModels:
res = caption
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
@@ -156,7 +162,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if shared.opts.interrogate_return_ranks:
+ res += f", ({match}:{score/100:.3f})"
+ else:
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/localization.py b/modules/localization.py
new file mode 100644
index 00000000..b1810cda
--- /dev/null
+++ b/modules/localization.py
@@ -0,0 +1,31 @@
+import json
+import os
+import sys
+import traceback
+
+localizations = {}
+
+
+def list_localizations(dirname):
+ localizations.clear()
+
+ for file in os.listdir(dirname):
+ fn, ext = os.path.splitext(file)
+ if ext.lower() != ".json":
+ continue
+
+ localizations[fn] = os.path.join(dirname, file)
+
+
+def localization_js(current_localization_name):
+ fn = localizations.get(current_localization_name, None)
+ data = {}
+ if fn is not None:
+ try:
+ with open(fn, "r", encoding="utf8") as file:
+ data = json.load(file)
+ except Exception:
+ print(f"Error loading localization from {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return f"var localization = {json.dumps(data)}\n"
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..f327c3df 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
import torch
-from modules.devices import get_optimal_device
+from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
- module.to(gpu)
+ module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
- sd_model.to(device)
+ sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(device)
+ sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff --git a/modules/ngrok.py b/modules/ngrok.py
new file mode 100644
index 00000000..5c5f349a
--- /dev/null
+++ b/modules/ngrok.py
@@ -0,0 +1,17 @@
+from pyngrok import ngrok, conf, exception
+
+
+def connect(token, port, region):
+ if token == None:
+ token = 'None'
+ config = conf.PyngrokConfig(
+ auth_token=token, region=region
+ )
+ try:
+ public_url = ngrok.connect(port, pyngrok_config=config).public_url
+ except exception.PyngrokNgrokError:
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
+ else:
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
diff --git a/modules/paths.py b/modules/paths.py
index 0519caa0..1e7a2fbc 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,7 @@
import argparse
import os
import sys
+import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
diff --git a/modules/processing.py b/modules/processing.py
index 04aed989..ff83023c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,9 +9,10 @@ from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
+from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -51,9 +52,15 @@ def get_correct_sampler(p):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
+ elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
+ return sd_samplers.samplers
-class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
+class StableDiffusionProcessing():
+ """
+ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
+
+ """
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -80,15 +87,16 @@ class StableDiffusionProcessing:
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
+ self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
- self.s_churn = opts.s_churn
- self.s_tmin = opts.s_tmin
- self.s_tmax = float('inf') # not representable as a standard ui option
- self.s_noise = opts.s_noise
+ self.s_churn = s_churn or opts.s_churn
+ self.s_tmin = s_tmin or opts.s_tmin
+ self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
+ self.s_noise = s_noise or opts.s_noise
if not seed_enable_extras:
self.subseed = -1
@@ -96,6 +104,13 @@ class StableDiffusionProcessing:
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -129,7 +144,7 @@ class Processed:
self.index_of_first_image = index_of_first_image
self.styles = p.styles
self.job_timestamp = state.job_timestamp
- self.clip_skip = opts.CLIP_ignore_last_layers
+ self.clip_skip = opts.CLIP_stop_at_last_layers
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -140,7 +155,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
@@ -207,7 +222,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -247,6 +262,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
+ if opts.eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
+
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -259,6 +277,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
+
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -274,7 +299,7 @@ def fix_seed(p):
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
- clip_skip = getattr(p, 'clip_skip', opts.CLIP_ignore_last_layers)
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
generation_params = {
"Steps": p.steps,
@@ -285,7 +310,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -293,12 +318,13 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
- "Clip skip": None if clip_skip==0 else clip_skip,
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -313,17 +339,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
- if p.outpath_samples is not None:
- os.makedirs(p.outpath_samples, exist_ok=True)
-
- if p.outpath_grids is not None:
- os.makedirs(p.outpath_grids, exist_ok=True)
-
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
@@ -332,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = p.prompt
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
- if os.path.exists(cmd_opts.embeddings_dir):
+ if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.run_alwayson_scripts(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -369,15 +396,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -392,15 +417,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
- if state.interrupted or state.skipped:
-
- # if we are interrupted, sample returns just noise
- # use the image collected previously in sampler loop
- samples_ddim = shared.state.current_latent
-
- samples_ddim = samples_ddim.to(devices.dtype)
-
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@@ -479,24 +497,23 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- firstphase_width = 0
- firstphase_height = 0
- firstphase_width_truncated = 0
- firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -505,54 +522,86 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
+ if self.firstphase_width == 0 or self.firstphase_height == 0:
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = self.width * self.height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
+ firstphase_width_truncated = int(scale * self.width)
+ firstphase_height_truncated = int(scale * self.height)
+
+ else:
+
+ width_ratio = self.width / self.firstphase_width
+ height_ratio = self.height / self.firstphase_height
+
+ if width_ratio > height_ratio:
+ firstphase_width_truncated = self.firstphase_width
+ firstphase_height_truncated = self.firstphase_width * self.height / self.width
+ else:
+ firstphase_width_truncated = self.firstphase_height * self.width / self.height
+ firstphase_height_truncated = self.firstphase_height
+
+ self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
+ self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
+ def create_dummy_mask(self, x, width=None, height=None):
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
-
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- if self.scale_latent:
+ if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ decoded_samples = decode_first_stage(self.sd_model, samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
-
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
+
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
@@ -566,7 +615,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
@@ -574,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -592,6 +641,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -664,6 +714,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
@@ -693,10 +747,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+ else:
+ self.image_conditioning = torch.zeros(
+ self.init_latent.shape[0], 5, 1, 1,
+ dtype=self.init_latent.dtype,
+ device=self.init_latent.device
+ )
+
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
@@ -704,4 +787,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples
+ return samples \ No newline at end of file
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 919d5d31..f70872c4 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -275,7 +275,7 @@ re_attention = re.compile(r"""
def parse_prompt_attention(text):
"""
- Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
diff --git a/modules/safe.py b/modules/safe.py
new file mode 100644
index 00000000..399165a1
--- /dev/null
+++ b/modules/safe.py
@@ -0,0 +1,117 @@
+# this code is adapted from the script contributed by anon from /h/
+
+import io
+import pickle
+import collections
+import sys
+import traceback
+
+import torch
+import numpy
+import _codecs
+import zipfile
+import re
+
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
+
+
+def encode(*args):
+ out = _codecs.encode(*args)
+ return out
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+ def persistent_load(self, saved_id):
+ assert saved_id[0] == 'storage'
+ return TypedStorage()
+
+ def find_class(self, module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return getattr(collections, name)
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ return getattr(torch._utils, name)
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ return getattr(torch, name)
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
+ return getattr(torch.nn.modules.container, name)
+ if module == 'numpy.core.multiarray' and name == 'scalar':
+ return numpy.core.multiarray.scalar
+ if module == 'numpy' and name == 'dtype':
+ return numpy.dtype
+ if module == '_codecs' and name == 'encode':
+ return encode
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
+ import pytorch_lightning.callbacks
+ return pytorch_lightning.callbacks.model_checkpoint
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
+ import pytorch_lightning.callbacks.model_checkpoint
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
+ if module == "__builtin__" and name == 'set':
+ return set
+
+ # Forbid everything else.
+ raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+
+
+allowed_zip_names = ["archive/data.pkl", "archive/version"]
+allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
+
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if name in allowed_zip_names:
+ continue
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}")
+
+
+def check_pt(filename):
+ try:
+
+ # new pytorch format is a zip file
+ with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
+ with z.open('archive/data.pkl') as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.load()
+
+ except zipfile.BadZipfile:
+
+ # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ with open(filename, "rb") as file:
+ unpickler = RestrictedUnpickler(file)
+ for i in range(5):
+ unpickler.load()
+
+
+def load(filename, *args, **kwargs):
+ from modules import shared
+
+ try:
+ if not shared.cmd_opts.disable_safe_unpickle:
+ check_pt(filename)
+
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
+ except Exception:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ return None
+
+ return unsafe_torch_load(filename, *args, **kwargs)
+
+
+unsafe_torch_load = torch.load
+torch.load = load
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..5bcccd67
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,53 @@
+
+callbacks_model_loaded = []
+callbacks_ui_tabs = []
+callbacks_ui_settings = []
+
+
+def clear_callbacks():
+ callbacks_model_loaded.clear()
+ callbacks_ui_tabs.clear()
+
+
+def model_loaded_callback(sd_model):
+ for callback in callbacks_model_loaded:
+ callback(sd_model)
+
+
+def ui_tabs_callback():
+ res = []
+
+ for callback in callbacks_ui_tabs:
+ res += callback() or []
+
+ return res
+
+
+def ui_settings_callback():
+ for callback in callbacks_ui_settings:
+ callback()
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument"""
+ callbacks_model_loaded.append(callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+ The function must either return a None, which means no new tabs to be added, or a list, where
+ each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ callbacks_ui_tabs.append(callback)
+
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settings are populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ callbacks_ui_settings.append(callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index 45230f9a..9323af3e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,83 +1,175 @@
import os
import sys
import traceback
+from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks
+
+AlwaysVisible = object()
+
class Script:
filename = None
args_from = None
args_to = None
+ alwayson = False
+
+ infotext_fields = None
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+ """
- # The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
raise NotImplementedError()
- # How the script is displayed in the UI. See https://gradio.app/docs/#components
- # for the different UI components you can use and how to create them.
- # Most UI components can return a value, such as a boolean for a checkbox.
- # The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+ Values of those returned componenbts will be passed to run() and process() functions.
+ """
+
pass
- # Determines when the script should be shown in the dropdown menu via the
- # returned value. As an example:
- # is_img2img is True if the current tab is img2img, and False if it is txt2img.
- # Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
+ """
+ is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
+
+ This function should return:
+ - False if the script should not be shown in UI at all
+ - True if the script should be shown in UI if it's scelected in the scripts drowpdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+
return True
- # This is where the additional processing is implemented. The parameters include
- # self, the model object "p" (a StableDiffusionProcessing class, see
- # processing.py), and the parameters returned by the ui method.
- # Custom functions can be defined here, and additional libraries can be imported
- # to be used in processing. The return value should be a Processed object, which is
- # what is returned by the process_images method.
- def run(self, *args):
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+ It must do all processing and return the Processed object with results, same as
+ one returned by processing.process_images.
+
+ Usually the processing is done by calling the processing.process_images function.
+
+ args contains all values returned by components from ui()
+ """
+
raise NotImplementedError()
- # The description method is currently unused.
- # To add a description that appears when hovering over the title, amend the "titles"
- # dict in script.js to include the script title (returned by title) as a key, and
- # your description as the value.
+ def process(self, p, *args):
+ """
+ This function is called before processing begins for AlwaysVisible scripts.
+ scripts. You can modify the processing object (p) here, inject hooks, etc.
+ """
+
+ pass
+
def describe(self):
+ """unused"""
return ""
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+
+
+def list_scripts(scriptdirname, extension):
+ scripts_list = []
+
+ basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(basedir):
+ for filename in sorted(os.listdir(basedir)):
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ for dirname in sorted(os.listdir(extdir)):
+ dirpath = os.path.join(extdir, dirname)
+ scriptdirpath = os.path.join(dirpath, scriptdirname)
+
+ if not os.path.isdir(scriptdirpath):
+ continue
+
+ for filename in sorted(os.listdir(scriptdirpath)):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+ return scripts_list
-def load_scripts(basedir):
- if not os.path.exists(basedir):
- return
- for filename in sorted(os.listdir(basedir)):
- path = os.path.join(basedir, filename)
+def list_files_with_name(filename):
+ res = []
- if not os.path.isfile(path):
+ dirs = [paths.script_path]
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
continue
+ path = os.path.join(dirpath, filename)
+ if os.path.isfile(filename):
+ res.append(path)
+
+ return res
+
+
+def load_scripts():
+ global current_basedir
+ scripts_data.clear()
+ script_callbacks.clear_callbacks()
+
+ scripts_list = list_scripts("scripts", ".py")
+
+ syspath = sys.path
+
+ for scriptfile in sorted(scripts_list):
try:
- with open(path, "r", encoding="utf8") as file:
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
+
+ with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
- compiled = compile(text, path, 'exec')
- module = ModuleType(filename)
+ compiled = compile(text, scriptfile.path, 'exec')
+ module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append((script_class, path))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
- print(f"Error loading script: {filename}", file=sys.stderr)
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ sys.path = syspath
+ current_basedir = paths.script_path
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -93,49 +185,84 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
+ self.titles = []
+ self.infotext_fields = []
def setup_ui(self, is_img2img):
- for script_class, path in scripts_data:
+ for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
- if not script.show(is_img2img):
- continue
+ visibility = script.show(is_img2img)
- self.scripts.append(script)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
- titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
- inputs = [dropdown]
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
- for script in self.scripts:
+ inputs = [None]
+ inputs_alwayson = [True]
+
+ def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
- continue
+ return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- control.visible = False
+ if not script.alwayson:
+ control.visible = False
+
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
+ for script in self.alwayson_scripts:
+ with gr.Group():
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
+ inputs[0] = dropdown
+
+ for script in self.selectable_scripts:
+ create_script_ui(script, inputs, inputs_alwayson)
+
def select_script(script_index):
- if 0 < script_index <= len(self.scripts):
- script = self.scripts[script_index-1]
+ if 0 < script_index <= len(self.selectable_scripts):
+ script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
+ def init_field(title):
+ if title == 'None':
+ return
+ script_index = self.titles.index(title)
+ script = self.selectable_scripts[script_index]
+ for i in range(script.args_from, script.args_to):
+ inputs[i].visible = True
+
+ dropdown.init_field = init_field
dropdown.change(
fn=select_script,
inputs=[dropdown],
@@ -150,7 +277,7 @@ class ScriptRunner:
if script_index == 0:
return None
- script = self.scripts[script_index-1]
+ script = self.selectable_scripts[script_index-1]
if script is None:
return None
@@ -162,7 +289,16 @@ class ScriptRunner:
return processed
- def reload_sources(self):
+ def run_alwayson_scripts(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process(p, *script_args)
+ except Exception:
+ print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
@@ -172,9 +308,12 @@ class ScriptRunner:
from types import ModuleType
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
+ module = cache.get(filename, None)
+ if module is None:
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+ cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -183,19 +322,22 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
+
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
def reload_script_body_only():
- scripts_txt2img.reload_sources()
- scripts_img2img.reload_sources()
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
-def reload_scripts(basedir):
+def reload_scripts():
global scripts_txt2img, scripts_img2img
- scripts_data.clear()
- load_scripts(basedir)
+ load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f12a9696..0f10828e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,8 +8,9 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -18,35 +19,43 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ if not invokeAI_mps_available and shared.device.type == 'mps':
+ print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print("Applying v1 cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ else:
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization.")
+ print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
def undo_optimizations():
+ from modules.hypernetworks import hypernetwork
+
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def get_target_prompt_token_count(token_count):
- if token_count < 75:
- return 75
-
- return math.ceil(token_count / 10) * 10
+ return math.ceil(max(token_count, 1) / 75) * 75
class StableDiffusionModelHijack:
@@ -110,6 +119,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
@@ -127,7 +138,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
@@ -140,6 +150,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes = []
remade_tokens = []
multipliers = []
+ last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
@@ -148,13 +159,33 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [id_end] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
+ iteration = len(remade_tokens) // 75
+ if (len(remade_tokens) + emb_len) // 75 != iteration:
+ rem = (75 * (iteration + 1) - len(remade_tokens))
+ remade_tokens += [id_end] * rem
+ multipliers += [1.0] * rem
+ iteration += 1
+ fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
@@ -162,10 +193,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens) + 1
+ tokens_to_add = prompt_target_length - len(remade_tokens)
- remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
- multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
+ remade_tokens = remade_tokens + [id_end] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -193,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -250,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -262,37 +292,70 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
-
- if opts.use_old_emphasis_implementation:
+ use_old = opts.use_old_emphasis_implementation
+ if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
- self.hijack.fixes = hijack_fixes
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
- target_token_count = get_target_prompt_token_count(token_count) + 2
+ if use_old:
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
- position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
- position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
+ z = None
+ i = 0
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
- remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
- tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
+ self.hijack.fixes = []
+ for unfiltered in hijack_fixes:
+ fixes = []
+ for fix in unfiltered:
+ if fix[0] == i:
+ fixes.append(fix[1])
+ self.hijack.fixes.append(fixes)
+
+ tokens = []
+ multipliers = []
+ for j in range(len(remade_batch_tokens)):
+ if len(remade_batch_tokens[j]) > 0:
+ tokens.append(remade_batch_tokens[j][:75])
+ multipliers.append(batch_multipliers[j][:75])
+ else:
+ tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+ multipliers.append([1.0] * 75)
- tmp = -opts.CLIP_ignore_last_layers
- if (opts.CLIP_ignore_last_layers == 0):
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
- z = outputs.last_hidden_state
- else:
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
- z = outputs.hidden_states[tmp]
+ z1 = self.process_tokens(tokens, multipliers)
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+ remade_batch_tokens = rem_tokens
+ batch_multipliers = rem_multipliers
+ i += 1
+
+ return z
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+ if opts.CLIP_stop_at_last_layers > 1:
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
@@ -321,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..fd92a335
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,331 @@
+import torch
+
+from einops import repeat
+from omegaconf import ListConfig
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+# =================================================================================================
+# Monkey patch DDIMSampler methods from RunwayML repo directly.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
+# =================================================================================================
+@torch.no_grad()
+def sample_ddim(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+@torch.no_grad()
+def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+
+# =================================================================================================
+# Monkey patch PLMSSampler methods.
+# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
+# Adapted from:
+# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
+# =================================================================================================
+@torch.no_grad()
+def sample_plms(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
+# =================================================================================================
+# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
+# =================================================================================================
+
+@torch.no_grad()
+def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ # todo: get null label from cond_stage_model
+ raise NotImplementedError()
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+ return c
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+ def __init__(
+ self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+ self.concat_keys = concat_keys
+
+
+def should_hijack_inpainting(checkpoint_info):
+ return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
+
+def do_inpainting_hijack():
+ ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+ ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+
+ ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 634fb4b2..98123fbf 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
import math
import sys
import traceback
+import importlib
import torch
from torch import einsum
@@ -9,12 +10,12 @@ from ldm.util import default
from einops import rearrange
from modules import shared
+from modules.hypernetworks import hypernetwork
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
import xformers.ops
- import functorch
- xformers._is_functorch_available = True
shared.xformers_available = True
except Exception:
print("Cannot import xformers", file=sys.stderr)
@@ -28,16 +29,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
- del context, x
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
@@ -61,22 +56,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2)
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
k_in *= self.scale
@@ -128,18 +117,111 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
+
+def check_for_psutil():
+ try:
+ spec = importlib.util.find_spec('psutil')
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+invokeAI_mps_available = check_for_psutil()
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI --
+if invokeAI_mps_available:
+ import psutil
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v):
+ s = einsum('b i d, b j d -> b i j', q, k)
+ s = s.softmax(dim=-1, dtype=s.dtype)
+ return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], slice_size):
+ end = i + slice_size
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+ return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+ return r
+
+def einsum_op_mps_v1(q, k, v):
+ if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ return einsum_op_compvis(q, k, v)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+ if mem_total_gb > 8 and q.shape[1] <= 4096:
+ return einsum_op_compvis(q, k, v)
+ else:
+ return einsum_op_slice_0(q, k, v, 1)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+ if size_mb <= max_tensor_mb:
+ return einsum_op_compvis(q, k, v)
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+ if div <= q.shape[0]:
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ # Divide factor of safety as there's copying and fragmentation
+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op(q, k, v):
+ if q.device.type == 'cuda':
+ return einsum_op_cuda(q, k, v)
+
+ if q.device.type == 'mps':
+ if mem_total_gb >= 32:
+ return einsum_op_mps_v1(q, k, v)
+ return einsum_op_mps_v2(q, k, v)
+
+ # Smaller slices are faster due to L2/L3/SLC caches.
+ # Tested on i7 with 8MB L3 cache.
+ return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k) * self.scale
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
@@ -214,10 +296,16 @@ def xformers_attnblock_forward(self, x):
try:
h_ = x
h_ = self.norm(h_)
- q1 = self.q(h_).contiguous()
- k1 = self.k(h_).contiguous()
- v = self.v(h_).contiguous()
- out = xformers.ops.memory_efficient_attention(q1, k1, v)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = xformers.ops.memory_efficient_attention(q, k, v)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
except NotImplementedError:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e63d3c29..e697bb72 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,4 +1,4 @@
-import glob
+import collections
import os.path
import sys
from collections import namedtuple
@@ -7,19 +7,21 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}
+checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
- from transformers import logging
+ from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
@@ -121,56 +123,110 @@ def select_checkpoint():
return checkpoint_info
+chckpoint_dict_replacements = {
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+}
+
+
+def transform_checkpoint_dict_key(k):
+ for text, replacement in chckpoint_dict_replacements.items():
+ if k.startswith(text):
+ k = replacement + k[len(text):]
+
+ return k
+
+
def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
- return pl_sd["state_dict"]
+ pl_sd = pl_sd["state_dict"]
+
+ sd = {}
+ for k, v in pl_sd.items():
+ new_key = transform_checkpoint_dict_key(k)
+
+ if new_key is not None:
+ sd[new_key] = v
+
+ pl_sd.clear()
+ pl_sd.update(sd)
return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ if checkpoint_info not in checkpoints_loaded:
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+
+ pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
- pl_sd = torch.load(checkpoint_file, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ missing, extra = model.load_state_dict(sd, strict=False)
- sd = get_state_dict_from_checkpoint(pl_sd)
+ if shared.cmd_opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
- model.load_state_dict(sd, strict=False)
+ if not shared.cmd_opts.no_half:
+ model.half()
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- if not shared.cmd_opts.no_half:
- model.half()
+ vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+ vae_file = shared.cmd_opts.vae_path
- vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
- if os.path.exists(vae_file):
- print(f"Loading VAE weights from: {vae_file}")
- vae_ckpt = torch.load(vae_file, map_location="cpu")
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ if os.path.exists(vae_file):
+ print(f"Loading VAE weights from: {vae_file}")
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ model.first_stage_model.load_state_dict(vae_dict)
- model.first_stage_model.load_state_dict(vae_dict)
+ model.first_stage_model.to(devices.dtype_vae)
+
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
+ else:
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ checkpoints_loaded.move_to_end(checkpoint_info)
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
+
+ if should_hijack_inpainting(checkpoint_info):
+ # Hardcoded config for now...
+ sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+ sd_config.model.params.use_ema = False
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.unet_config.params.in_channels = 9
+
+ # Create a "fake" config with a different name so that we know to unload it when switching models.
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+ do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -182,6 +238,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
+ shared.sd_model = sd_model
+
+ script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
return sd_model
@@ -194,8 +253,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
- shared.sd_model = load_model()
+ if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ checkpoints_loaded.clear()
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -208,6 +268,7 @@ def reload_model_weights(sd_model, info=None):
load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6e743f7e..0b408a70 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -57,7 +57,7 @@ def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
+ hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
@@ -71,6 +71,7 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,14 +83,22 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def sample_to_image(samples):
- x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+def single_sample_to_image(sample):
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+def sample_to_image(samples):
+ return single_sample_to_image(samples[0])
+
+
+def samples_to_image_grid(samples):
+ return images.image_grid([single_sample_to_image(sample) for sample in samples])
+
+
def store_latent(decoded):
state.current_latent = decoded
@@ -98,25 +107,8 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
-
-def extended_tdqm(sequence, *args, desc=None, **kwargs):
- state.sampling_steps = len(sequence)
- state.sampling_step = 0
-
- seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
- for x in seq:
- if state.interrupted or state.skipped:
- break
-
- yield x
-
- state.sampling_step += 1
- shared.total_tqdm.update()
-
-
-ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
-ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
+class InterruptedException(BaseException):
+ pass
class VanillaStableDiffusionSampler:
@@ -128,14 +120,40 @@ class VanillaStableDiffusionSampler:
self.init_latent = None
self.sampler_noises = None
self.step = 0
+ self.stop_at = None
self.eta = None
self.default_eta = 0.0
self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except InterruptedException:
+ return self.last_latent
+
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ if state.interrupted or state.skipped:
+ raise InterruptedException
+
+ if self.stop_at is not None and self.step > self.stop_at:
+ raise InterruptedException
+
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -156,14 +174,25 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
- store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
else:
- store_latent(res[1])
+ self.last_latent = res[1]
+
+ store_latent(self.last_latent)
self.step += 1
+ state.sampling_step = self.step
+ shared.total_tqdm.update()
+
return res
def initialize(self, p):
@@ -176,7 +205,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
@@ -190,25 +219,38 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
+ self.last_latent = x
self.step = 0
- samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+
+ samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
+ self.last_latent = x
self.step = 0
steps = steps or p.steps
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
# existing code fails with certain step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
@@ -222,7 +264,10 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
- def forward(self, x, sigma, uncond, cond, cond_scale):
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ if state.interrupted or state.skipped:
+ raise InterruptedException
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@@ -230,28 +275,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -268,25 +314,6 @@ class CFGDenoiser(torch.nn.Module):
return denoised
-def extended_trange(sampler, count, *args, **kwargs):
- state.sampling_steps = count
- state.sampling_step = 0
-
- seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
- for x in seq:
- if state.interrupted or state.skipped:
- break
-
- if sampler.stop_at is not None and x > sampler.stop_at:
- break
-
- yield x
-
- state.sampling_step += 1
- shared.total_tqdm.update()
-
-
class TorchHijack:
def __init__(self, kdiff_sampler):
self.kdiff_sampler = kdiff_sampler
@@ -314,9 +341,30 @@ class KDiffusionSampler:
self.eta = None
self.default_eta = 1.0
self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d):
- store_latent(d["denoised"])
+ step = d['i']
+ latent = d["denoised"]
+ store_latent(latent)
+ self.last_latent = latent
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except InterruptedException:
+ return self.last_latent
def number_of_needed_noises(self, p):
return p.steps
@@ -339,9 +387,6 @@ class KDiffusionSampler:
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
- if hasattr(k_diffusion.sampling, 'trange'):
- k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
-
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
@@ -355,7 +400,7 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
@@ -365,18 +410,35 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
- noise = noise * sigmas[steps - t_enc - 1]
- xi = x + noise
-
- extra_params_kwargs = self.initialize(p)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
@@ -396,6 +458,14 @@ class KDiffusionSampler:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
return samples
diff --git a/modules/shared.py b/modules/shared.py
index 6ecc2503..b55371d3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,6 +3,7 @@ import datetime
import json
import os
import sys
+from collections import OrderedDict
import gradio as gr
import tqdm
@@ -13,7 +14,8 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers, sd_models, localization
+from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -25,16 +27,22 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
+parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
+parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
+parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -46,15 +54,17 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -62,26 +72,50 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
+parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-
+parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
cmd_opts = parser.parse_args()
-
-devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+restricted_opts = [
+ "samples_filename_pattern",
+ "outdir_samples",
+ "outdir_txt2img_samples",
+ "outdir_img2img_samples",
+ "outdir_extras_samples",
+ "outdir_grids",
+ "outdir_txt2img_grids",
+ "outdir_save",
+]
+
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
+weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
skipped = False
@@ -107,7 +141,7 @@ class State:
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
-
+
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
@@ -123,6 +157,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
+localization.list_localizations(cmd_opts.localizations_dir)
+
def realesrgan_models_names():
import modules.realesrgan_model
@@ -130,13 +166,14 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
- self.section = None
+ self.section = section
+ self.refresh = refresh
def options_section(section_identifier, options_dict):
@@ -159,6 +196,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
+ "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -169,6 +207,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -198,6 +237,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -212,9 +252,19 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+}))
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
- "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -222,31 +272,41 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
- 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}),
+ 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
+ "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
+ "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
@@ -257,6 +317,16 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
+}))
+
+options_templates.update(options_section(('images-history', "Images Browser"), {
+ #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+ "images_history_preload": OptionInfo(False, "Preload images at startup"),
+ "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
+ "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
+ "images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
+
}))
@@ -316,10 +386,26 @@ class Options:
item = self.data_labels.get(key)
item.onchange = func
+ func()
+
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
opts = Options()
if os.path.exists(config_filename):
@@ -329,6 +415,8 @@ sd_upscalers = []
sd_model = None
+clip_model = None
+
progress_print_out = sys.stdout
diff --git a/modules/styles.py b/modules/styles.py
index d44dfc1a..3bf5c5b6 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -45,7 +45,7 @@ class StyleDatabase:
if not os.path.exists(path):
return
- with open(path, "r", encoding="utf8", newline='') as file:
+ with open(path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
@@ -79,7 +79,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+ with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index fbd11f84..baa02e3d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -10,6 +10,7 @@ from tqdm import tqdm
from modules import modelloader
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
+from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
@@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
filename = path
if filename is None or not os.path.exists(filename):
return None
- model = net(
+ if filename.endswith(".v2.pth"):
+ model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
- resi_connection="3conv",
- )
+ resi_connection="1conv",
+ )
+ params = None
+ else:
+ model = net(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+ embed_dim=240,
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="3conv",
+ )
+ params = "params_ema"
pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
+ if params is not None:
+ model.load_state_dict(pretrained_model[params], strict=True)
+ else:
+ model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py
new file mode 100644
index 00000000..0e28ae6e
--- /dev/null
+++ b/modules/swinir_model_arch_v2.py
@@ -0,0 +1,1017 @@
+# -----------------------------------------------------------------------------------
+# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
+# Written by Conde and Choi et al.
+# -----------------------------------------------------------------------------------
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ #assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1],
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+class Upsample_hf(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample_hf, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+
+
+class Swin2SR(nn.Module):
+ r""" Swin2SR
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
+ img_range: Image range. 1. or 255.
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(Swin2SR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+
+ if self.upsampler == 'pixelshuffle_hf':
+ self.layers_hf = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers_hf.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffle_aux':
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_after_aux = nn.Sequential(
+ nn.Conv2d(3, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffle_hf':
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ self.conv_before_upsample_hf = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward_features_hf(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers_hf:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffle_aux':
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
+ bicubic = self.conv_bicubic(bicubic)
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
+ x = self.conv_after_aux(aux)
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
+ x = self.conv_last(x)
+ aux = aux / self.img_range + self.mean
+ elif self.upsampler == 'pixelshuffle_hf':
+ # for classical SR with HF
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x_before = self.conv_before_upsample(x)
+ x_out = self.conv_last(self.upsample(x_before))
+
+ x_hf = self.conv_first_hf(x_before)
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
+ x_hf = self.conv_before_upsample_hf(x_hf)
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
+ x = x_out + x_hf
+ x_hf = x_hf / self.img_range + self.mean
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+ if self.upsampler == "pixelshuffle_aux":
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
+
+ elif self.upsampler == "pixelshuffle_hf":
+ x_out = x_out / self.img_range + self.mean
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
+
+ else:
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = Swin2SR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape) \ No newline at end of file
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,18 +8,28 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.size = size
+ self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -33,16 +43,28 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -53,29 +75,47 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
- self.length = len(self.dataset) * repeats
+ if include_cond:
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.dataset.append(entry)
+
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
+ self.length = len(self.dataset) * repeats // batch_size
+
+ self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+
+ def create_text(self, filename_text):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", filename_text)
+ return text
def __len__(self):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
+ res = []
- index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
- return x, text
+ return res
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
new file mode 100644
index 00000000..ea653806
--- /dev/null
+++ b/modules/textual_inversion/image_embedding.py
@@ -0,0 +1,220 @@
+import base64
+import json
+import numpy as np
+import zlib
+from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from fonts.ttf import Roboto
+import torch
+from modules.shared import opts
+
+
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, d):
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
+ return d
+
+
+def embedding_to_b64(data):
+ d = json.dumps(data, cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+
+def embedding_from_b64(data):
+ d = base64.b64decode(data)
+ return json.loads(d, cls=EmbeddingDecoder)
+
+
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed % 255
+
+
+def xor_block(block):
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
+
+
+def style_block(block, sequence):
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i = 0
+ for x in range(-6, im.size[0], 8):
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
+ offset = 0
+ if yi % 2 == 0:
+ offset = 4
+ shade = sequence[i % len(sequence)]
+ i += 1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+
+ return block ^ fg
+
+
+def insert_image_data_embed(image, data):
+ d = 3
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
+ data_np_high = data_np_ >> 4
+ data_np_low = data_np_ & 0x0F
+
+ h = image.size[1]
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
+
+ data_np_low.resize(next_size)
+ data_np_low = data_np_low.reshape((h, -1, d))
+
+ data_np_high.resize(next_size)
+ data_np_high = data_np_high.reshape((h, -1, d))
+
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
+
+ data_np_low = style_block(data_np_low, sequence=edge_style)
+ data_np_low = xor_block(data_np_low)
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
+ data_np_high = xor_block(data_np_high)
+
+ im_low = Image.fromarray(data_np_low, mode='RGB')
+ im_high = Image.fromarray(data_np_high, mode='RGB')
+
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
+ background.paste(im_low, (0, 0))
+ background.paste(image, (im_low.size[0]+1, 0))
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
+
+ return background
+
+
+def crop_black(img, tol=0):
+ mask = (img > tol).all(2)
+ mask0, mask1 = mask.any(0), mask.any(1)
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end, col_start:col_end]
+
+
+def extract_image_data_embed(image):
+ d = 3
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
+ if black_cols[0].shape[0] < 2:
+ print('No Image data blocks found.')
+ return None
+
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
+
+ data_block_lower = xor_block(data_block_lower)
+ data_block_upper = xor_block(data_block_upper)
+
+ data_block = (data_block_upper << 4) | (data_block_lower)
+ data_block = data_block.flatten().tobytes()
+
+ data = zlib.decompress(data_block)
+ return json.loads(data, cls=EmbeddingDecoder)
+
+
+def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from math import cos
+
+ image = srcimage.copy()
+ fontsize = 32
+ if textfont is None:
+ try:
+ textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
+ textfont = opts.font or Roboto
+ except Exception:
+ textfont = Roboto
+
+ factor = 1.5
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
+ for y in range(image.size[1]):
+ mag = 1-cos(y/image.size[1]*factor)
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+ draw = ImageDraw.Draw(image)
+
+ font = ImageFont.truetype(textfont, fontsize)
+ padding = 10
+
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
+ font = ImageFont.truetype(textfont, fontsize)
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
+
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+
+ font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
+
+ return image
+
+
+if __name__ == '__main__':
+
+ testEmbed = Image.open('test_embedding.png')
+ data = extract_image_data_embed(testEmbed)
+ assert data is not None
+
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
+ assert data is not None
+
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
+
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
+
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
+
+ retrived_embed = extract_image_data_embed(embedded_image)
+
+ assert str(retrived_embed) == str(test_embed)
+
+ embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
+
+ assert embedded_image == embedded_image2
+
+ g = lcg()
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
+
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
+
+ assert shared_random == reference_random
+
+ hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
+
+ assert 12731374 == hunna_kay_random_sum
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..2062726a
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,69 @@
+import tqdm
+
+
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..33eaddb6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,16 +1,46 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
+import time
from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+ import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
@@ -21,84 +51,97 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- if process_caption:
- shared.interrogator.load()
+ def save_pic_with_caption(image, index, existing_caption=None):
+ caption = ""
- def save_pic_with_caption(image, index):
if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
- def save_pic(image, index):
- save_pic_with_caption(image, index)
+ def save_pic(image, index, existing_caption=None):
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
+
+ def split_pic(image, inverse_xy):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
+
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
if shared.state.interrupted:
break
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
-
- if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
-
- top = img.crop((0, 0, size, size))
- save_pic(top, index)
-
- bot = img.crop((0, img.height - size, size, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
-
- left = img.crop((0, 0, size, size))
- save_pic(left, index)
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
- right = img.crop((img.width - size, 0, img.width, size))
- save_pic(right, index)
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy):
+ save_pic(splitted, index, existing_caption=existing_caption)
else:
- img = images.resize_image(1, img, size, size)
- save_pic(img, index)
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
- else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
- break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png
new file mode 100644
index 00000000..07e2d9af
--- /dev/null
+++ b/modules/textual_inversion/test_embedding.png
Binary files differ
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..529ed3e2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,11 +6,17 @@ import torch
import tqdm
import html
import datetime
+import csv
+from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -80,7 +86,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
+ else:
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -120,6 +137,7 @@ class EmbeddingDatabase:
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -135,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -147,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -156,7 +175,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -178,43 +223,51 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
+ embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
+ for i, entries in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
-
- x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
@@ -223,30 +276,83 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
- steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
+ do_not_reload_embeddings=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
@@ -254,7 +360,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
@@ -268,4 +374,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename)
return embedding, filename
-
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -23,6 +23,8 @@ def preprocess(*args):
def train_embedding(*args):
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
try:
sd_hijack.undo_optimizations()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..c9d5a090 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,12 +1,13 @@
import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -30,10 +31,14 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
- scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None,
+ firstphase_width=firstphase_width if enable_hr else None,
+ firstphase_height=firstphase_height if enable_hr else None,
)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -52,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
diff --git a/modules/ui.py b/modules/ui.py
index dad509f3..2311572c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,40 +5,54 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
+import tempfile
import time
import traceback
-import platform
-import subprocess as sp
-from functools import reduce
+from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
-from modules.shared import opts, cmd_opts
+
+from modules.shared import opts, cmd_opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
+import modules.hypernetworks.ui
+
+import modules.images_history as img_his
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -50,6 +64,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+if cmd_opts.ngrok != None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
+
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
@@ -72,6 +91,10 @@ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+save_style_symbol = '\U0001f4be' # 💾
+apply_style_symbol = '\U0001f4cb' # 📋
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -79,6 +102,14 @@ def plaintext_to_html(text):
def image_from_url_text(filedata):
+ if type(filedata) == dict and filedata["is_file"]:
+ filename = filedata["name"]
+ tempdir = os.path.normpath(tempfile.gettempdir())
+ normfn = os.path.normpath(filename)
+ assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
+
+ return Image.open(filename)
+
if type(filedata) == list:
if len(filedata) == 0:
return None
@@ -101,7 +132,7 @@ def send_gradio_gallery_to_image(x):
def save_files(js_data, images, do_make_zip, index):
- import csv
+ import csv
filenames = []
fullfns = []
@@ -125,6 +156,8 @@ def save_files(js_data, images, do_make_zip, index):
images = [images[index]]
start_index = index
+ os.makedirs(opts.outdir_save, exist_ok=True)
+
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
writer = csv.writer(file)
@@ -132,10 +165,7 @@ def save_files(js_data, images, do_make_zip, index):
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
for image_index, filedata in enumerate(images, start_index):
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
+ image = image_from_url_text(filedata)
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
@@ -165,6 +195,23 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+def save_pil_to_file(pil_image, dir=None):
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
@@ -175,8 +222,15 @@ def wrap_gradio_call(func, extra_outputs=None):
try:
res = list(func(*args, **kwargs))
except Exception as e:
+ # When printing out our debug argument list, do not print out more than a MB of text
+ max_debug_str_len = 131072 # (1024*1024)/8
+
print("Error completing request", file=sys.stderr)
- print("Arguments:", args, kwargs, file=sys.stderr)
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ print(argStr[:max_debug_str_len], file=sys.stderr)
+ if len(argStr) > max_debug_str_len:
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
@@ -218,6 +272,24 @@ def wrap_gradio_call(func, extra_outputs=None):
return f
+def calc_time_left(progress, threshold, label, force_display):
+ if progress == 0:
+ return ""
+ else:
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ if eta_relative > 3600:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ elif eta_relative > 60:
+ return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+ else:
+ return label + time.strftime('%Ss', time.gmtime(eta_relative))
+ else:
+ return ""
+
+
def check_progress_call(id_part):
if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False), gr_show(False)
@@ -229,11 +301,15 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
+ if time_left != "":
+ shared.state.time_left_force_display = True
+
progress = min(progress, 1)
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -242,7 +318,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+ if opts.show_progress_grid:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
@@ -265,6 +344,8 @@ def check_progress_call_initial(id_part):
shared.state.current_latent = None
shared.state.current_image = None
shared.state.textinfo = None
+ shared.state.time_start = time.time()
+ shared.state.time_left_force_display = False
return check_progress_call(id_part)
@@ -286,7 +367,7 @@ def visit(x, func, path=""):
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
- return [gr_show(), gr_show()]
+ return [gr_show() for x in range(4)]
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
shared.prompt_styles.styles[style.name] = style
@@ -411,27 +492,38 @@ def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id="toprow"):
- with gr.Column(scale=4):
+ with gr.Column(scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
- with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ with gr.Column(scale=1, elem_id="roll_col"):
+ roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+ paste = gr.Button(value=paste_symbol, elem_id="paste")
+ save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+ prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- with gr.Column(scale=10, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- with gr.Row():
- with gr.Column(scale=8):
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ button_interrogate = None
+ button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_id="interrogate_col"):
+ button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+ if cmd_opts.deepdanbooru:
+ button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
with gr.Row():
@@ -451,20 +543,16 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row(scale=1):
- if is_img2img:
- interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- if cmd_opts.deepdanbooru:
- deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- else:
- deepbooru = None
- else:
- interrogate = None
- deepbooru = None
- prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
- save_style = gr.Button('Create style', elem_id="style_create")
+ with gr.Row():
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ prompt_style.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ prompt_style2.save_to_config = True
+
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -488,13 +576,67 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None):
)
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
+ # dont allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data[key]
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
@@ -520,11 +662,12 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- scale_latent = gr.Checkbox(label='Scale latent', value=False)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
- with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ with gr.Row(equal_height=True):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@@ -540,7 +683,7 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
send_to_img2img = gr.Button('Send to img2img')
@@ -551,13 +694,13 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
+
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -582,9 +725,11 @@ def create_ui(wrap_gradio_gpu_call):
height,
width,
enable_hr,
- scale_latent,
denoising_strength,
+ firstphase_width,
+ firstphase_height,
] + custom_inputs,
+
outputs=[
txt2img_gallery,
generation_info,
@@ -596,6 +741,17 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
@@ -648,14 +804,30 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (firstphase_width, "First pass size-1"),
+ (firstphase_height, "First pass size-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
- modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
with gr.Column(scale=1):
pass
@@ -669,10 +841,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
@@ -702,15 +874,15 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
@@ -728,7 +900,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
- with gr.Group():
+ with gr.Column():
with gr.Row():
save = gr.Button('Save')
img2img_send_to_img2img = gr.Button('Send to img2img')
@@ -739,17 +911,28 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
+
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+ img2img_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ img2img_prompt_img
+ ],
+ outputs=[
+ img2img_prompt,
+ img2img_prompt_img
+ ]
+ )
+
mask_mode.change(
lambda mode, img: {
init_img_with_mask: gr_show(mode == 0),
@@ -889,8 +1072,8 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
- modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface:
@@ -903,13 +1086,30 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch Process'):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
- upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ with gr.TabItem('Batch from Directory'):
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
+ placeholder="A directory on the same machine where the server is running."
+ )
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
+ placeholder="Leave blank to save images to the default path."
+ )
+ show_extras_results = gr.Checkbox(label='Show result images', value=True)
+
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by'):
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ with gr.TabItem('Scale to'):
+ with gr.Group():
+ with gr.Row():
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
- extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group():
@@ -930,17 +1130,25 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+
submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
+ dummy_component,
extras_image,
image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
gfpgan_visibility,
codeformer_visibility,
codeformer_weight,
upscaling_resize,
+ upscaling_resize_w,
+ upscaling_resize_h,
+ upscaling_crop,
extras_upscaler_1,
extras_upscaler_2,
extras_upscaler_2_visibility,
@@ -951,17 +1159,17 @@ def create_ui(wrap_gradio_gpu_call):
html_info,
]
)
-
+
extras_send_to_img2img.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_img2img",
inputs=[result_images],
outputs=[init_img],
)
-
+
extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
+ _js="extract_image_from_gallery_inpaint",
inputs=[result_images],
outputs=[init_img_with_mask],
)
@@ -985,6 +1193,14 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[image],
outputs=[html, generation_info, html2],
)
+ #images history
+ images_history_switch_dict = {
+ "fn": modules.generation_parameters_copypaste.connect_paste,
+ "t2i": txt2img_paste_fields,
+ "i2i": img2img_paste_fields
+ }
+
+ images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
@@ -992,11 +1208,12 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
@@ -1005,35 +1222,58 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
- with gr.Blocks() as textual_inversion_interface:
+ with gr.Blocks() as train_interface:
with gr.Row().style(equal_height=False):
- with gr.Column():
- with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
+ with gr.Row().style(equal_height=False):
+ with gr.Tabs(elem_id="train_tabs"):
+ with gr.Tab(label="Create embedding"):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
- with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
+ with gr.Tab(label="Create hypernetwork"):
+ new_hypernetwork_name = gr.Textbox(label="Name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+ with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
- process_caption = gr.Checkbox(label='Use BLIP caption as filename')
+ process_split = gr.Checkbox(label='Split oversized images')
+ process_caption = gr.Checkbox(label='Use BLIP for caption')
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
with gr.Row():
with gr.Column(scale=3):
@@ -1042,25 +1282,40 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
- with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
- train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
+ with gr.Tab(label="Train"):
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
+ with gr.Row():
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
+ with gr.Row():
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
- with gr.Column(scale=2):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_embedding = gr.Button(value="Train", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1078,6 +1333,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1086,15 +1342,39 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
process_src,
process_dst,
+ process_width,
+ process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
],
outputs=[
ti_output,
@@ -1107,13 +1387,43 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
+ batch_size,
dataset_directory,
log_directory,
+ training_width,
+ training_height,
steps,
create_image_every,
save_embedding_every,
template_file,
+ save_image_with_stored_embedding,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_hypernetwork_name,
+ hypernetwork_learn_rate,
+ batch_size,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_from_txt2img,
+ *txt2img_preview_params,
],
outputs=[
ti_output,
@@ -1127,7 +1437,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
- def create_setting_component(key):
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1147,13 +1457,33 @@ def create_ui(wrap_gradio_gpu_call):
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ elem_id = "setting_"+key
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
def open_folder(f):
- if not os.path.isdir(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
@@ -1174,15 +1504,23 @@ Requested path was: {f}
def run_settings(*args):
changed = 0
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if not opts.same_type(value, opts.data_labels[key].default):
- return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+ if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
+ return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ if comp == dummy_component:
+ continue
+
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ continue
+
oldval = opts.data.get(key, None)
opts.data[key] = value
@@ -1196,6 +1534,26 @@ Requested path was: {f}
return f'{changed} settings changed.', opts.dumpjson()
+ def run_settings_single(value, key):
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ oldval = opts.data.get(key, None)
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ return gr.update(value=oldval), opts.dumpjson()
+
+ opts.data[key] = value
+
+ if oldval != value:
+ if opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+
+ return gr.update(value=value), opts.dumpjson()
+
with gr.Blocks(analytics_enabled=False) as settings_interface:
settings_submit = gr.Button(value="Apply settings", variant='primary')
result = gr.HTML()
@@ -1203,6 +1561,11 @@ Requested path was: {f}
settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
+ quicksettings_list = []
+
cols_displayed = 0
items_displayed = 0
previous_section = None
@@ -1223,14 +1586,26 @@ Requested path was: {f}
previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
+ elem_id, text = item.section
+ gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
- items_displayed += 1
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
+ quicksettings_list.append((i, k, item))
+ components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
+ items_displayed += 1
+
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+
+ with gr.Row():
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1238,13 +1613,16 @@ Requested path was: {f}
_js='function(){}'
)
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
-
+ download_localization.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='download_localization'
+ )
def reload_scripts():
modules.scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
@@ -1263,7 +1641,7 @@ Requested path was: {f}
outputs=[],
_js='function(){restart_reload()}'
)
-
+
if column is not None:
column.__exit__()
@@ -1272,31 +1650,44 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
+ (images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (textual_inversion_interface, "Textual inversion", "ti"),
- (settings_interface, "Settings", "settings"),
+ (train_interface, "Train", "ti"),
]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
-
+ with gr.Row(elem_id="quicksettings"):
+ for i, k, item in quicksettings_list:
+ component = create_setting_component(k, is_quicksettings=True)
+ component_dict[k] = component
+
settings_interface.gradio_ref = demo
-
- with gr.Tabs() as tabs:
+
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid):
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
-
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
@@ -1306,7 +1697,16 @@ Requested path was: {f}
inputs=components,
outputs=[result, text_settings],
)
-
+
+ for i, k, item in quicksettings_list:
+ component = component_dict[k]
+
+ component.change(
+ fn=lambda value, k=k: run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, text_settings],
+ )
+
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
@@ -1322,6 +1722,7 @@ Requested path was: {f}
inputs=[
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
interp_method,
interp_amount,
save_as_half,
@@ -1331,6 +1732,7 @@ Requested path was: {f}
submit_result,
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
component_dict['sd_model_checkpoint'],
]
)
@@ -1397,8 +1799,22 @@ Requested path was: {f}
outputs=[extras_image],
)
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
+ modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
+
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
@@ -1415,20 +1831,24 @@ Requested path was: {f}
print(traceback.format_exc(), file=sys.stderr)
def loadsave(path, x):
- def apply_field(obj, field, condition=None):
+ def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
if getattr(obj,'custom_script_source',None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
-
+
if getattr(obj, 'do_not_save_to_config', False):
return
-
+
saved_value = ui_settings.get(key, None)
if saved_value is None:
ui_settings[key] = getattr(obj, field)
- elif condition is None or condition(saved_value):
+ elif condition and not condition(saved_value):
+ print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ else:
setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
apply_field(x, 'visible')
@@ -1447,13 +1867,20 @@ Requested path was: {f}
if type(x) == gr.Textbox:
apply_field(x, 'value')
-
+
if type(x) == gr.Number:
apply_field(x, 'value')
-
+
+ # Since there are many dropdowns that shouldn't be saved,
+ # we only mark dropdowns that should be saved.
+ if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
+ apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+ apply_field(x, 'visible')
+
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
+ visit(modelmerger_interface, loadsave, "modelmerger")
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
@@ -1462,21 +1889,29 @@ Requested path was: {f}
return demo
-with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f'<script>{jsfile.read()}</script>'
+def load_javascript(raw_response):
+ with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
+ javascript = f'<script>{jsfile.read()}</script>'
-jsdir = os.path.join(script_path, "javascript")
-for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
- javascript += f"\n<script>{jsfile.read()}</script>"
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
+ javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
+ if cmd_opts.theme is not None:
+ javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+
+ javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
-if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
- res = gradio_routes_templates_response(*args, **kwargs)
- res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res = raw_response(*args, **kwargs)
+ res.body = res.body.replace(
+ b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
return res
- gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response
+
+
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
+reload_javascript()