diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/extras.py | 57 | ||||
-rw-r--r-- | modules/images.py | 1 | ||||
-rw-r--r-- | modules/processing.py | 18 | ||||
-rw-r--r-- | modules/sd_samplers.py | 28 | ||||
-rw-r--r-- | modules/shared.py | 12 | ||||
-rw-r--r-- | modules/ui.py | 59 |
6 files changed, 153 insertions, 22 deletions
diff --git a/modules/extras.py b/modules/extras.py index 382ffa7d..c4ee2b62 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,9 @@ import os import numpy as np
from PIL import Image
+import torch
+import tqdm
+
from modules import processing, shared, images, devices
from modules.shared import opts
import modules.gfpgan_model
@@ -135,3 +138,57 @@ def run_pnginfo(image): info = f"<div><p>{message}<p></div>"
return '', geninfo, info
+
+
+def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
+    """Merge two checkpoints by interpolating their weights and save the result.
+
+    Each model argument is either a path to an existing file or a bare model
+    name, which is resolved to ``models/<name>.ckpt``. Entries of the two
+    ``state_dict``s whose key contains ``'model'`` are blended with the
+    selected interpolation; such keys that exist only in the second
+    checkpoint are copied over unchanged. The first checkpoint object, with
+    its updated state dict, is then written into ``models/``.
+
+    Args:
+        modelname_0: Path or name of the checkpoint merged into ("to").
+        modelname_1: Path or name of the checkpoint merged from ("from").
+        interp_method: ``"Weighted Sum"`` or ``"Sigmoid"``.
+        interp_amount: Blend factor handed to the interpolation function;
+            0 keeps the first model's weights, 1 takes the second model's.
+
+    Returns:
+        A status string naming the saved checkpoint file.
+
+    Raises:
+        KeyError: If ``interp_method`` is not a known method, or a loaded
+            checkpoint has no ``'state_dict'`` entry.
+    """
+
+    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
+    def weighted_sum(theta0, theta1, alpha):
+        return ((1 - alpha) * theta0) + (alpha * theta1)
+
+    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep); exposed in the UI
+    # under the label "Sigmoid".
+    def smoothstep(theta0, theta1, alpha):
+        alpha = alpha * alpha * (3 - (2 * alpha))
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    # Returns (checkpoint filename, short model name used in the output name).
+    def resolve_checkpoint(modelname):
+        if os.path.exists(modelname):
+            return modelname, os.path.splitext(os.path.basename(modelname))[0]
+        return 'models/' + modelname + '.ckpt', modelname
+
+    theta_funcs = {
+        "Weighted Sum": weighted_sum,
+        "Sigmoid": smoothstep,
+    }
+    # Look the method up before touching the disk so an unknown method fails
+    # immediately instead of after both checkpoints have been loaded.
+    theta_func = theta_funcs[interp_method]
+
+    model0_filename, modelname_0 = resolve_checkpoint(modelname_0)
+    model1_filename, modelname_1 = resolve_checkpoint(modelname_1)
+
+    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
+    # Only merge checkpoints that come from a trusted source.
+    print(f"Loading {model0_filename}...")
+    model_0 = torch.load(model0_filename, map_location='cpu')
+
+    print(f"Loading {model1_filename}...")
+    model_1 = torch.load(model1_filename, map_location='cpu')
+
+    theta_0 = model_0['state_dict']
+    theta_1 = model_1['state_dict']
+
+    print("Merging...")
+    # Only existing keys are reassigned here, so the dict does not change
+    # size while it is being iterated.
+    for key in tqdm.tqdm(theta_0.keys()):
+        if 'model' in key and key in theta_1:
+            theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
+
+    # Carry over weights that exist only in the second checkpoint.
+    for key in theta_1.keys():
+        if 'model' in key and key not in theta_0:
+            theta_0[key] = theta_1[key]
+
+    method_tag = interp_method.replace(" ", "_")
+    output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-' + method_tag + '-' + str(interp_amount) + '-merged.ckpt'
+
+    # Inputs may be given as paths outside models/, so that directory is not
+    # guaranteed to exist yet; create it rather than failing after the merge.
+    os.makedirs('models', exist_ok=True)
+
+    print(f"Saving to {output_modelname}...")
+    torch.save(model_0, output_modelname)
+
+    print("Checkpoint saved.")
+    return "Checkpoint saved to " + output_modelname
diff --git a/modules/images.py b/modules/images.py index ae0e6304..9458bf8d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -295,6 +295,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
if cmd_opts.hide_ui_dir_config:
diff --git a/modules/processing.py b/modules/processing.py index 0246e094..8d043f4d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -78,7 +78,14 @@ class StableDiffusionProcessing: self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
+
+ self.ddim_eta = opts.ddim_eta
+ self.ddim_discretize = opts.ddim_discretize
+ self.s_churn = opts.s_churn
+ self.s_tmin = opts.s_tmin
+ self.s_tmax = float('inf') # not representable as a standard ui option
+ self.s_noise = opts.s_noise
+
if not seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
@@ -117,6 +124,13 @@ class Processed: self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image
+ self.ddim_eta = p.ddim_eta
+ self.ddim_discretize = p.ddim_discretize
+ self.s_churn = p.s_churn
+ self.s_tmin = p.s_tmin
+ self.s_tmax = p.s_tmax
+ self.s_noise = p.s_noise
+
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@@ -406,7 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 1fc9d18c..666ee1ee 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -37,6 +37,11 @@ samplers = [ ]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+# Extra keyword arguments per k-diffusion sampling function, keyed by the
+# function's name. KDiffusionSampler looks its funcname up in this table and
+# forwards each listed attribute that is present on the processing object.
+sampler_extra_params = {
+ 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
+}
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
@@ -120,9 +125,9 @@ class VanillaStableDiffusionSampler: # existing code fails with cetain step counts, like 9
try:
- self.sampler.make_schedule(ddim_num_steps=steps, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
@@ -149,9 +154,9 @@ class VanillaStableDiffusionSampler: # existing code fails with cetin step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
return samples_ddim
@@ -224,6 +229,7 @@ class KDiffusionSampler: self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname,[])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
@@ -269,7 +275,12 @@ class KDiffusionSampler: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
@@ -286,7 +297,12 @@ class KDiffusionSampler: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
diff --git a/modules/shared.py b/modules/shared.py index c32da110..84302438 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -66,7 +66,7 @@ class State: job = ""
job_no = 0
job_count = 0
- job_timestamp = 0
+ job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
@@ -80,6 +80,7 @@ class State: self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
+
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
@@ -169,7 +170,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
- "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
+ "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
@@ -219,6 +220,13 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
}))
+# Options section for sampler parameters. The processing object reads these
+# values from opts (ddim_eta, ddim_discretize, s_churn, s_tmin, s_noise).
+# s_tmax has no entry here: it is set to float('inf') in code because that
+# value is not representable as a standard UI option.
+options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
+ "ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+}))
class Options:
data = None
diff --git a/modules/ui.py b/modules/ui.py index 3b9c8525..9a3d69c8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -50,6 +50,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
+.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
@@ -398,7 +399,7 @@ def setup_progressbar(progressbar, preview, id_part): )
-def create_ui(txt2img, img2img, run_extras, run_pnginfo):
+def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -569,13 +570,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False)
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False)
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
with gr.Row():
- mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask")
+ mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
@@ -858,6 +859,33 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): outputs=[html, generation_info, html2],
)
+ with gr.Blocks() as modelmerger_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+ gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
+
+ modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)")
+ modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)")
+ interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
+ submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+
+ with gr.Column(variant='panel'):
+ submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+
+ submit.click(
+ fn=run_modelmerger,
+ inputs=[
+ modelname_0,
+ modelname_1,
+ interp_method,
+ interp_amount
+ ],
+ outputs=[
+ submit_result,
+ ]
+ )
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -955,6 +983,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
+ (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(settings_interface, "Settings", "settings"),
]
@@ -975,6 +1004,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid):
interface.render()
+
+ if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
@@ -983,18 +1015,21 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): outputs=[result, text_settings],
)
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
+ txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
+ img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
send_to_img2img.click(
- fn=lambda x: (image_from_url_text(x)),
- _js="extract_image_from_gallery_img2img",
- inputs=[txt2img_gallery],
- outputs=[init_img],
+ fn=lambda img, *args: (image_from_url_text(img),*args),
+ _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
+ inputs=[txt2img_gallery] + txt2img_fields,
+ outputs=[init_img] + img2img_fields,
)
send_to_inpaint.click(
- fn=lambda x: (image_from_url_text(x)),
- _js="extract_image_from_gallery_inpaint",
- inputs=[txt2img_gallery],
- outputs=[init_img_with_mask],
+ fn=lambda x, *args: (image_from_url_text(x), *args),
+ _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
+ inputs=[txt2img_gallery] + txt2img_fields,
+ outputs=[init_img_with_mask] + img2img_fields,
)
img2img_send_to_img2img.click(
|