aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-05 00:25:37 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-05 00:25:37 +0000
commita8a58dbac7b205ae90664c3b249d60e4baa2855c (patch)
tree68d8912439be31cb6445250319d6133050cbba79
parentf91d0c3d19ac2b849cd15a519f711f9281e782cd (diff)
downloadstable-diffusion-webui-gfx803-a8a58dbac7b205ae90664c3b249d60e4baa2855c.tar.gz
stable-diffusion-webui-gfx803-a8a58dbac7b205ae90664c3b249d60e4baa2855c.tar.bz2
stable-diffusion-webui-gfx803-a8a58dbac7b205ae90664c3b249d60e4baa2855c.zip
re-integrated tiling option as a UI element
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/processing.py6
-rw-r--r--modules/sd_hijack.py20
-rw-r--r--modules/shared.py2
-rw-r--r--modules/txt2img.py5
-rw-r--r--modules/ui.py4
-rw-r--r--webui.py5
7 files changed, 34 insertions, 11 deletions
diff --git a/modules/img2img.py b/modules/img2img.py
index b1ef1326..e6707f96 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
-def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
is_loopback = mode == 2
is_upscale = mode == 3
@@ -37,6 +37,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
width=width,
height=height,
use_GFPGAN=use_GFPGAN,
+ tiling=tiling,
init_images=[image],
mask=mask,
mask_blur=mask_blur,
diff --git a/modules/processing.py b/modules/processing.py
index adc5d851..a5b2afb9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,6 +9,7 @@ import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
+import modules.sd_hijack
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -28,7 +29,7 @@ def torch_gc():
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -44,6 +45,7 @@ class StableDiffusionProcessing:
self.width: int = width
self.height: int = height
self.use_GFPGAN: bool = use_GFPGAN
+ self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params
@@ -110,6 +112,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
os.makedirs(p.outpath_samples, exist_ok=True)
os.makedirs(p.outpath_grids, exist_ok=True)
+ modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+
comments = []
if type(prompt) == list:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9779c30c..2d26b5f7 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -49,6 +49,8 @@ class StableDiffusionModelHijack:
fixes = None
comments = []
dir_mtime = None
+ layers = None
+ circular_enabled = False
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
@@ -105,6 +107,24 @@ class StableDiffusionModelHijack:
if cmd_opts.opt_split_attention:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ def flatten(el):
+ flattened = [flatten(children) for children in el.children()]
+ res = [el]
+ for c in flattened:
+ res += c
+ return res
+
+ self.layers = flatten(m)
+
+ def apply_circular(self, enable):
+ if self.circular_enabled == enable:
+ return
+
+ self.circular_enabled = enable
+
+ for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
+ layer.padding_mode = 'circular' if enable else 'zeros'
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
diff --git a/modules/shared.py b/modules/shared.py
index 0722185d..9e744f6c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -30,8 +30,6 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--opt-split-attention", action='store_true', help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance")
-parser.add_argument("--tiling", action='store_true', help="causes the model to generate images that can be tiled")
-
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index fb65a7f6..dfce49ff 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -21,7 +21,8 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
cfg_scale=cfg_scale,
width=width,
height=height,
- use_GFPGAN=use_GFPGAN
+ use_GFPGAN=use_GFPGAN,
+ tiling=tiling,
)
processed = modules.scripts.scripts_txt2img.run(p, *args)
diff --git a/modules/ui.py b/modules/ui.py
index 4119369e..a2f1124e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -155,6 +155,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+ tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
@@ -195,6 +196,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
steps,
sampler_index,
use_gfpgan,
+ tiling,
batch_count,
batch_size,
cfg_scale,
@@ -256,6 +258,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+ tiling = gr.Checkbox(label='Tiling', value=False)
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
with gr.Row():
@@ -339,6 +342,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
mask_blur,
inpainting_fill,
use_gfpgan,
+ tiling,
switch_mode,
batch_count,
batch_size,
diff --git a/webui.py b/webui.py
index 6f483482..dbc9dd54 100644
--- a/webui.py
+++ b/webui.py
@@ -140,11 +140,6 @@ try:
except Exception:
pass
-
-if cmd_opts.tiling:
- # this has to be done before the model is loaded
- modules.sd_hijack.add_circular_option_to_conv_2d()
-
sd_config = OmegaConf.load(cmd_opts.config)
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())