diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 4 | ||||
-rw-r--r-- | modules/api/models.py | 1 | ||||
-rw-r--r-- | modules/extras.py | 30 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 4 | ||||
-rw-r--r-- | modules/img2img.py | 58 | ||||
-rw-r--r-- | modules/prompt_parser.py | 4 | ||||
-rw-r--r-- | modules/scripts.py | 6 | ||||
-rw-r--r-- | modules/sd_disable_initialization.py | 71 | ||||
-rw-r--r-- | modules/sd_models.py | 5 | ||||
-rw-r--r-- | modules/shared.py | 8 | ||||
-rw-r--r-- | modules/textual_inversion/preprocess.py | 7 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 16 | ||||
-rw-r--r-- | modules/ui.py | 112 | ||||
-rw-r--r-- | modules/ui_components.py | 6 |
14 files changed, 191 insertions, 141 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 6c564ad8..5767ba90 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -286,7 +286,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) # avoid dividing zero progress = 0.01 @@ -308,7 +308,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) def interrogateapi(self, interrogatereq: InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 034b4aa0..c78095ca 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -168,6 +168,7 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os
import sys
import traceback
+import shutil
import numpy as np
from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ return sd_models.find_checkpoint_config(x) if x else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+
+ if cfg is None:
+ return
+
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+
+ print("Copying config:")
+ print(" from:", cfg)
+ print(" to:", checkpoint_filename)
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models()
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 300d3975..194679e8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
diff --git a/modules/img2img.py b/modules/img2img.py index ca58b5d8..f62783c6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
- is_inpaint = mode == 1
- is_batch = mode == 2
-
- if is_inpaint:
- # Drawn mask
- if mask_mode == 0:
- is_mask_sketch = isinstance(init_img_with_mask, dict)
- is_mask_paint = not is_mask_sketch
- if is_mask_sketch:
- # Sketch: mask iff. not transparent
- image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- else:
- # Color-sketch: mask iff. painted over
- image = init_img_with_mask
- orig = init_img_with_mask_orig or init_img_with_mask
- pred = np.any(np.array(image) != np.array(orig), axis=-1)
- mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
- mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
- blur = ImageFilter.GaussianBlur(mask_blur)
- image = Image.composite(image.filter(blur), orig, mask.filter(blur))
-
- image = image.convert("RGB")
- # Uploaded mask
- else:
- image = init_img_inpaint
- mask = init_mask_inpaint
- # No mask
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+ is_batch = mode == 5
+
+ if mode == 0: # img2img
+ image = init_img.convert("RGB")
+ mask = None
+ elif mode == 1: # img2img sketch
+ image = sketch.convert("RGB")
+ mask = None
+ elif mode == 2: # inpaint
+ image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ image = image.convert("RGB")
+ elif mode == 3: # inpaint sketch
+ image = inpaint_color_sketch
+ orig = inpaint_color_sketch_orig or inpaint_color_sketch
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+ image = image.convert("RGB")
+ elif mode == 4: # inpaint upload mask
+ image = init_img_inpaint
+ mask = init_mask_inpaint
else:
- image = init_img
+ image = None
mask = None
# Use the EXIF orientation of photos taken by smartphones.
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f70872c4..870218db 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -49,6 +49,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[5, 'a c'], [10, 'a {b|d{ c']]
>>> g("((a][:b:c [d:3]")
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ >>> g("[a|(b:1.1)]")
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
"""
def collect_steps(steps, tree):
@@ -84,7 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value
def __default__(self, data, children, meta):
for child in children:
- yield from child
+ yield child
return AtStep().transform(tree)
def get_schedule(prompt):
diff --git a/modules/scripts.py b/modules/scripts.py index 35164093..4ffc369b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -152,7 +152,7 @@ def basedir(): scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
-ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension):
@@ -206,7 +206,7 @@ def load_scripts(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -241,7 +241,7 @@ class ScriptRunner: self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- for script_class, path, basedir in scripts_data:
+ for script_class, path, basedir, script_module in scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 088ac24b..c72d8efc 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,6 +20,19 @@ class DisableInitialization: ```
"""
+ def __init__(self):
+ self.replaced = []
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
def __enter__(self):
def do_nothing(*args, **kwargs):
pass
@@ -37,11 +50,14 @@ class DisableInitialization: def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
# this file is always 404, prevent making request
- if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
- raise transformers.utils.hub.EntryNotFoundError
+ if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
+ return None
try:
- return original(url, *args, local_files_only=True, **kwargs)
+ res = original(url, *args, local_files_only=True, **kwargs)
+ if res is None:
+ res = original(url, *args, local_files_only=False, **kwargs)
+ return res
except Exception as e:
return original(url, *args, local_files_only=False, **kwargs)
@@ -54,42 +70,19 @@ class DisableInitialization: def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
- self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
- self.init_no_grad_normal = torch.nn.init._no_grad_normal_
- self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
- self.create_model_and_transforms = open_clip.create_model_and_transforms
- self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
- self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None)
- self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None)
- self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None)
- self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None)
-
- torch.nn.init.kaiming_uniform_ = do_nothing
- torch.nn.init._no_grad_normal_ = do_nothing
- torch.nn.init._no_grad_uniform_ = do_nothing
- open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
- ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
- if self.transformers_modeling_utils_load_pretrained_model is not None:
- transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model
- if self.transformers_tokenization_utils_base_cached_file is not None:
- transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file
- if self.transformers_configuration_utils_cached_file is not None:
- transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file
- if self.transformers_utils_hub_get_from_cache is not None:
- transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache
+ self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
+ self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
- torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
- torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
- torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
- open_clip.create_model_and_transforms = self.create_model_and_transforms
- ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
- if self.transformers_modeling_utils_load_pretrained_model is not None:
- transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model
- if self.transformers_tokenization_utils_base_cached_file is not None:
- transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file
- if self.transformers_configuration_utils_cached_file is not None:
- transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file
- if self.transformers_utils_hub_get_from_cache is not None:
- transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
diff --git a/modules/sd_models.py b/modules/sd_models.py index b5bc12f0..c466f273 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -333,10 +333,15 @@ def load_model(checkpoint_info=None): timer = Timer()
+ sd_model = None
+
try:
with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
+ pass
+
+ if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
diff --git a/modules/shared.py b/modules/shared.py index aa37c8ce..1c964237 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading, errors
+from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
from modules.paths import models_path, script_path, sd_path
@@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
-parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
+parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
+parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -387,7 +387,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
- "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index feb876c6..3c1042ad 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.process_caption_deepbooru = process_caption_deepbooru
params.preprocess_txt_action = preprocess_txt_action
- for index, imagefile in enumerate(tqdm.tqdm(files)):
+ pbar = tqdm.tqdm(files)
+ for index, imagefile in enumerate(pbar):
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
@@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre except Exception:
continue
+ description = f"Preprocessing [Image {index}/{len(files)}]"
+ pbar.set_description(description)
+ shared.state.textinfo = description
+
params.src = filename
existing_caption = None
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5420903f..853246a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,6 +9,7 @@ import tqdm import html
import datetime
import csv
+import safetensors.torch
from PIL import Image, PngImagePlugin
@@ -150,6 +151,8 @@ class EmbeddingDatabase: name = data.get('name', name)
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
else:
return
@@ -245,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -473,7 +479,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
diff --git a/modules/ui.py b/modules/ui.py index 3c458ce8..e86a624b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,53 +795,67 @@ def create_ui(): with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"):
+ with gr.Tabs(elem_id="mode_img2img"):
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
- with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
- init_img_with_mask_orig = gr.State(None)
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
- use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
- if use_color_sketch:
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ inpaint_color_sketch_orig = gr.State(None)
- init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
-
- with FormRow():
- mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
@@ -900,20 +914,6 @@ def create_ui(): ]
)
- mask_mode.change(
- lambda mode, img: {
- init_img_with_mask: gr_show(mode == 0),
- init_img_inpaint: gr_show(mode == 1),
- init_mask_inpaint: gr_show(mode == 1),
- },
- inputs=[mask_mode, init_img_with_mask],
- outputs=[
- init_img_with_mask,
- init_img_inpaint,
- init_mask_inpaint,
- ],
- )
-
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
@@ -924,11 +924,12 @@ def create_ui(): img2img_prompt_style,
img2img_prompt_style2,
init_img,
+ sketch,
init_img_with_mask,
- init_img_with_mask_orig,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
init_img_inpaint,
init_mask_inpaint,
- mask_mode,
steps,
sampler_index,
mask_blur,
@@ -1129,7 +1130,7 @@ def create_ui(): with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
- with gr.Row():
+ with FormRow():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1143,11 +1144,13 @@ def create_ui(): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|