aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/esrgan_model.py2
-rw-r--r--modules/hypernetwork.py2
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/processing.py7
-rw-r--r--modules/prompt_parser.py9
-rw-r--r--modules/sd_hijack.py59
-rw-r--r--modules/sd_hijack_optimizations.py60
-rw-r--r--modules/sd_models.py6
-rw-r--r--modules/sd_samplers.py39
-rw-r--r--modules/shared.py9
-rw-r--r--modules/ui.py10
11 files changed, 159 insertions, 46 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index d17e730f..28548124 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -111,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index c7b86682..7f062242 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -43,7 +43,7 @@ class Hypernetwork:
def load_hypernetworks(path):
res = {}
- for filename in glob.iglob(path + '**/*.pt', recursive=True):
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
try:
hn = Hypernetwork(filename)
res[hn.name] = hn
diff --git a/modules/img2img.py b/modules/img2img.py
index da212d72..24126774 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -32,6 +32,8 @@ def process_batch(p, input_dir, output_dir, args):
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
+ if state.skipped:
+ state.skipped = False
if state.interrupted:
break
diff --git a/modules/processing.py b/modules/processing.py
index f773a30e..8240ee27 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -141,6 +141,7 @@ class Processed:
self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
+
def js(self):
obj = {
"prompt": self.prompt,
@@ -312,6 +313,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+ modules.sd_hijack.model_hijack.clear_comments()
comments = {}
@@ -349,6 +351,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ if state.skipped:
+ state.skipped = False
+
if state.interrupted:
break
@@ -375,7 +380,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
- if state.interrupted:
+ if state.interrupted or state.skipped:
# if we are interruped, sample returns just noise
# use the image collected previously in sampler loop
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f00256f2..15666073 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index d68f89cc..335a2bcf 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -18,16 +18,17 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
-
- if cmd_opts.opt_split_attention_v1:
+ if cmd_opts.xformers and shared.xformers_available and not torch.version.hip:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ elif cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
@@ -37,6 +38,13 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def get_target_prompt_token_count(token_count):
+ if token_count < 75:
+ return 75
+
+ return math.ceil(token_count / 10) * 10
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -82,10 +90,12 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
+ def clear_comments(self):
+ self.comments = []
+
def tokenize(self, text):
- max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
- return remade_batch_tokens[0], token_count, max_length
+ return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
@@ -94,7 +104,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
- self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
@@ -116,7 +125,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
@@ -148,19 +156,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ prompt_target_length = get_target_prompt_token_count(token_count)
+ tokens_to_add = prompt_target_length - len(remade_tokens) + 1
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+ remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
+ multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -177,7 +178,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
- remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@@ -191,7 +193,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
+ maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
@@ -263,17 +265,24 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
- self.hijack.comments = hijack_comments
+ self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
- tokens = torch.asarray(remade_batch_tokens).to(device)
- outputs = self.wrapped.transformer(input_ids=tokens)
+ target_token_count = get_target_prompt_token_count(token_count) + 2
+
+ position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
+ position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
+
+ remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
+ tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
+ outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(device)
+ batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index d9cca485..d23d733b 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,4 +1,7 @@
import math
+import sys
+import traceback
+
import torch
from torch import einsum
@@ -7,18 +10,37 @@ from einops import rearrange
from modules import shared
+if shared.cmd_opts.xformers:
+ try:
+ import xformers.ops
+ import functorch
+ xformers._is_functorch_available = True
+ shared.xformers_available = True
+ except Exception:
+ print("Cannot import xformers", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
- q = self.to_q(x)
+ q_in = self.to_q(x)
context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
+
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
del context, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
@@ -31,6 +53,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
+ del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -105,6 +128,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
+def xformers_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -167,3 +209,13 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
+
+def xformers_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_).contiguous()
+ k1 = self.k(h_).contiguous()
+ v = self.v(h_).contiguous()
+ out = xformers.ops.memory_efficient_attention(q1, k1, v)
+ out = self.proj_out(out)
+ return x+out
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8f794b47..9409d070 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
+
+ if "state_dict" in pl_sd:
+ sd = pl_sd["state_dict"]
+ else:
+ sd = pl_sd
model.load_state_dict(sd, strict=False)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index df17e93c..eade0dbb 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
- if state.interrupted:
+ if state.interrupted or state.skipped:
break
yield x
@@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler:
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
+ # filling unconditional_conditioning with repeats of the last vector to match length is
+ # not 100% correct but should work well enough
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
+ last_vector = unconditional_conditioning[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
@@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module):
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- cond_in = torch.cat([tensor, uncond])
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ if tensor.shape[1] == uncond.shape[1]:
+ cond_in = torch.cat([tensor, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
else:
x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ b = min(a + batch_size, tensor.shape[0])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
- denoised_uncond = x_out[-batch_size:]
+ denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
@@ -254,7 +275,7 @@ def extended_trange(sampler, count, *args, **kwargs):
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
- if state.interrupted:
+ if state.interrupted or state.skipped:
break
if sampler.stop_at is not None and x > sampler.stop_at:
diff --git a/modules/shared.py b/modules/shared.py
index 879d8424..02cb2722 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@@ -73,7 +74,7 @@ device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-
+xformers_available = False
config_filename = cmd_opts.ui_settings_file
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
@@ -84,6 +85,7 @@ def selected_hypernetwork():
class State:
+ skipped = False
interrupted = False
job = ""
job_no = 0
@@ -96,6 +98,9 @@ class State:
current_image_sampling_step = 0
textinfo = None
+ def skip(self):
+ self.skipped = True
+
def interrupt(self):
self.interrupted = True
@@ -118,8 +123,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-# This was moved to webui.py with the other model "setup" calls.
-# modules.sd_models.list_models()
def realesrgan_models_names():
diff --git a/modules/ui.py b/modules/ui.py
index 4e136ef5..30583fe9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -192,6 +192,7 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
@@ -417,9 +418,16 @@ def create_toprow(is_img2img):
with gr.Column(scale=1):
with gr.Row():
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
@@ -952,7 +960,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
- save_as_half = gr.Checkbox(value=False, label="Safe as float16")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):