Diffstat (limited to 'modules'):
 modules/api/api.py                                 |  44
 modules/api/models.py                              |  10
 modules/extras.py                                  |   2
 modules/generation_parameters_copypaste.py         | 101
 modules/images.py                                  |  36
 modules/img2img.py                                 |   2
 modules/interrogate.py                             |   2
 modules/memmon.py                                  |   3
 modules/modelloader.py                             |  20
 modules/processing.py                              |  73
 modules/script_callbacks.py                        |  24
 modules/sd_hijack.py                               |  25
 modules/sd_hijack_clip.py                          |  10
 modules/sd_hijack_inpainting.py                    | 232
 modules/sd_hijack_xlmr.py                          |  34
 modules/sd_models.py                               |   5
 modules/sd_samplers.py                             |   4
 modules/sd_vae.py                                  |  31
 modules/shared.py                                  |  32
 modules/textual_inversion/textual_inversion.py     |  32
 modules/txt2img.py                                 |   8
 modules/ui.py                                      | 578
 modules/ui_components.py                           |  25
 modules/ui_tempdir.py                              |  24
 modules/upscaler.py                                |   6
 modules/xlmr.py                                    | 137
 26 files changed, 871 insertions(+), 629 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 1ceba75d..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -100,6 +100,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -121,7 +122,6 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -129,15 +129,14 @@ class Api:
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
- p = StableDiffusionProcessingTxt2Img(**vars(populate))
- # Override object param
-
- shared.state.begin()
with self.queue_lock:
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+ shared.state.begin()
processed = process_images(p)
+            shared.state.end()
 
-        shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -153,7 +152,6 @@ class Api:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
@@ -165,16 +163,14 @@ class Api:
args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in the model, but for a reason I cannot determine, that doesn't work
- p = StableDiffusionProcessingImg2Img(**args)
-
- p.init_images = [decode_base64_to_image(x) for x in init_images]
-
- shared.state.begin()
with self.queue_lock:
- processed = process_images(p)
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
 
-        shared.state.end()
+ shared.state.begin()
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -332,6 +328,26 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+ return {
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
+ }
+
def refresh_checkpoints(self):
shared.refresh_checkpoints()
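
A minimal client sketch for the endpoint added above, assuming a local webui on its default port (7860) and the third-party requests package; the base URL is an assumption, not part of this diff:

```python
import requests

# GET the newly registered /sdapi/v1/embeddings route
resp = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings")
resp.raise_for_status()
data = resp.json()

# both maps go from embedding name to its metadata (see models.py below)
for name, item in data["loaded"].items():
    print(f"{name}: {item['vectors']} vector(s) of length {item['shape']}")
print("skipped:", list(data["skipped"]))
```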
diff --git a/modules/api/models.py b/modules/api/models.py
index c446ce7a..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
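
For illustration, a response conforming to these models might deserialize to the following (every value here is made up):

```python
# hypothetical payload matching EmbeddingsResponse
example_response = {
    "loaded": {
        "my-style": {
            "step": 5000,                         # training steps, if recorded
            "sd_checkpoint": "925997e9",          # checkpoint hash (stable id)
            "sd_checkpoint_name": "v1-5-pruned",  # trainer-side checkpoint name
            "shape": 768,                         # length of each vector
            "vectors": 2,                         # number of vectors
        },
    },
    "skipped": {},
}
```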
diff --git a/modules/extras.py b/modules/extras.py
index 68939dea..5e270250 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
+ assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
+
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
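
A minimal illustration of what the new guard catches; the tensors are made up, and running this raises the AssertionError instead of failing deep inside theta_func2:

```python
import torch

a = torch.zeros(320, 320)   # stand-in for theta_0[key]
b = torch.zeros(320, 768)   # stand-in for the secondary model's tensor

# mirrors the check added above; mismatched layers now fail with a clear message
assert a.shape == b.shape, f'Incompatible shapes: A is {a.shape}, and B is {b.shape}'
```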
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fbd91300..4baf4d9a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,12 +1,13 @@
import base64
import io
+import math
import os
import re
from pathlib import Path
import gradio as gr
from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
import tempfile
from PIL import Image
@@ -36,9 +37,12 @@ def quote(text):
def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"]
- is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
@@ -93,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
- buttons[tab] = gr.Button(f"Send to {tab}")
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons
@@ -102,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
+def send_image_and_dimensions(x):
+ if isinstance(x, Image.Image):
+ img = x
+ else:
+ img = image_from_url_text(x)
+
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+
+ return img, w, h
+
+
def run_bind():
- for buttons, send_image, send_generate_info in bind_list:
+ for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
- if send_image and paste_fields[tab]["init_img"]:
- if type(send_image) == gr.Gallery:
- button.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery",
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ destination_image_component = paste_fields[tab]["init_img"]
+ fields = paste_fields[tab]["fields"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if source_image_component and destination_image_component:
+ if isinstance(source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
else:
- button.click(
- fn=lambda x: x,
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
- if send_generate_info and paste_fields[tab]["fields"] is not None:
+ if send_generate_info and fields is not None:
if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
)
else:
- connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+ connect_paste(button, fields, send_generate_info)
button.click(
fn=None,
@@ -164,6 +190,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
return None
+def restore_old_hires_fix_params(res):
+ """for infotexts that specify old First pass size parameter, convert it into
+ width, height, and hr scale"""
+
+ firstpass_width = res.get('First pass size-1', None)
+ firstpass_height = res.get('First pass size-2', None)
+
+ if firstpass_width is None or firstpass_height is None:
+ return
+
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+ width = int(res.get("Size-1", 512))
+ height = int(res.get("Size-2", 512))
+
+ if firstpass_width == 0 or firstpass_height == 0:
+ # old algorithm for auto-calculating first pass size
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ firstpass_width = math.ceil(scale * width / 64) * 64
+ firstpass_height = math.ceil(scale * height / 64) * 64
+
+ hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
+
+ res['Size-1'] = firstpass_width
+ res['Size-2'] = firstpass_height
+ res['Hires upscale'] = hr_scale
+
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -221,6 +276,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ restore_old_hires_fix_params(res)
+
return res
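
Tracing restore_old_hires_fix_params above on a hypothetical parse result shows the conversion:

```python
# parsed from a hypothetical old infotext containing
# "Size: 1024x1024, First pass size: 512x512"
res = {
    "Size-1": "1024", "Size-2": "1024",
    "First pass size-1": "512", "First pass size-2": "512",
}

restore_old_hires_fix_params(res)

# Size-1/Size-2 now hold the first-pass resolution, and the ratio of the
# old final size to the first-pass size becomes the upscale factor
assert res["Size-1"] == 512 and res["Size-2"] == 512
assert res["Hires upscale"] == 2.0
```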
diff --git a/modules/images.py b/modules/images.py
index 31d4528d..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows)
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+
w, h = imgs[0].size
-    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
 
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
+    for i, img in enumerate(params.imgs):
+        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid
@@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+ """
+ Resizes an image with the specified resize_mode, width, and height.
+
+ Args:
+ resize_mode: The mode to use when resizing the image.
+ 0: Resize the image to the specified width and height.
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
+ im: The image to resize.
+ width: The width to resize the image to.
+ height: The height to resize the image to.
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+ """
+
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
def resize(im, w, h):
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
- assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
@@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image_to_save.mode == 'RGBA':
+ image_to_save = image_to_save.convert("RGB")
+
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
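
A usage sketch for the extended resize_image signature; "Lanczos" is assumed to be present in shared.sd_upscalers (available upscaler names depend on the installation):

```python
from PIL import Image
from modules import images

im = Image.new("RGB", (512, 512))

# mode 0 resizes to the exact target; the explicit upscaler_name now
# overrides opts.upscaler_for_img2img for the upscale step
resized = images.resize_image(0, im, 768, 768, upscaler_name="Lanczos")
```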
diff --git a/modules/img2img.py b/modules/img2img.py
index 81da4b13..ca58b5d8 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 46935210..6f761c5a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -135,7 +135,7 @@ class InterrogateModels:
return caption[0]
def interrogate(self, pil_image):
- res = None
+ res = ""
try:
diff --git a/modules/memmon.py b/modules/memmon.py
index 9fb9b687..a7060f58 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
free, total = torch.cuda.mem_get_info()
+ self.data["free"] = free
self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device)
+ self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"]
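
With the added keys, read() now reports current usage next to the existing peaks; a sketch assuming an already-running MemUsageMonitor named mem_mon on a CUDA device:

```python
data = mem_mon.read()

# newly exposed current values (in bytes), alongside the existing peak values
print("free/total:", data["free"], data["total"])
print("active:", data["active"], "peak:", data["active_peak"])
print("reserved:", data["reserved"], "peak:", data["reserved_peak"])
```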
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
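
The intended sequencing follows from the code above; the call order below is a sketch of an assumed startup flow, not taken from this diff:

```python
from modules import modelloader

# snapshot every Upscaler subclass defined so far; these count as builtin
modelloader.list_builtin_upscalers()

# ... extensions load here and may define further Upscaler subclasses ...

# blacklist any subclass that appeared after the snapshot
modelloader.forbid_loaded_nonbuiltin_upscalers()

# later loads silently skip blacklisted classes
modelloader.load_upscalers()
```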
diff --git a/modules/processing.py b/modules/processing.py
index 0a9a8f95..a172af0b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -239,7 +239,7 @@ class StableDiffusionProcessing():
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -247,6 +247,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
+ self.comments = comments
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
@@ -646,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
@@ -657,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
- self.firstphase_width = firstphase_width
- self.firstphase_height = firstphase_height
- self.truncate_x = 0
- self.truncate_y = 0
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+
+ if firstphase_width != 0 or firstphase_height != 0:
+ print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+ self.hr_scale = self.width / firstphase_width
+ self.width = firstphase_width
+ self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -673,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
-
- if self.firstphase_width == 0 or self.firstphase_height == 0:
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- firstphase_width_truncated = int(scale * self.width)
- firstphase_height_truncated = int(scale * self.height)
-
- else:
-
- width_ratio = self.width / self.firstphase_width
- height_ratio = self.height / self.firstphase_height
-
- if width_ratio > height_ratio:
- firstphase_width_truncated = self.firstphase_width
- firstphase_height_truncated = self.firstphase_width * self.height / self.width
- else:
- firstphase_width_truncated = self.firstphase_height * self.width / self.height
- firstphase_height_truncated = self.firstphase_height
-
- self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
- self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ if self.hr_upscaler is not None:
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and latent_scale_mode is None:
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
-
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ target_width = int(self.width * self.hr_scale)
+ target_height = int(self.height * self.hr_scale)
- """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index):
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
@@ -722,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
- if opts.use_scale_latent_for_hires_fix:
+ if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -746,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i)
- image = images.resize_image(0, image, self.width, self.height)
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
@@ -763,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
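
A construction sketch for the reworked hires fix (model and option setup omitted; "Latent" is assumed to be a key of shared.latent_upscale_modes, which routes through the F.interpolate branch above):

```python
p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    prompt="a photo of a cat",
    width=512,
    height=512,
    enable_hr=True,
    denoising_strength=0.6,
    hr_scale=2.0,          # second pass targets 1024x1024
    hr_upscaler="Latent",  # latent-space upscale instead of an image upscaler
)
```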
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 8e22f875..de69fd9f 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -51,6 +51,13 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params
+class ImageGridLoopParams:
+ def __init__(self, imgs, cols, rows):
+ self.imgs = imgs
+ self.cols = cols
+ self.rows = rows
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
@@ -63,6 +70,7 @@ callback_map = dict(
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
+ callbacks_image_grid=[],
)
@@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
report_exception(c, 'after_component_callback')
+def image_grid_callback(params: ImageGridLoopParams):
+ for c in callback_map['callbacks_image_grid']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'image_grid')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -255,3 +271,11 @@ def on_before_component(callback):
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+ """register a function to be called before making an image grid.
+ The callback is called with one argument:
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+ """
+ add_callback(callback_map['callbacks_image_grid'], callback)
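
An example callback for the new hook; forcing a single-column grid is only an illustration:

```python
from modules import script_callbacks

def single_column_grid(params):
    # params is the ImageGridLoopParams defined above; its fields
    # (imgs, cols, rows) may be modified in place before the grid is built
    params.cols = 1
    params.rows = len(params.imgs)

script_callbacks.on_image_grid(single_column_grid)
```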
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 019a6f3f..55a684cc 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -65,6 +65,7 @@ def fix_checkpoint():
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -75,17 +76,25 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
- if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+