diff options
40 files changed, 1612 insertions, 428 deletions
@@ -25,3 +25,4 @@ __pycache__ /.idea notification.mp3 /SwinIR +/textual_inversion @@ -11,12 +11,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
-- Prompt
-- Stable Diffusion upscale
+- Prompt Matrix
+- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- - a man in a ((txuedo)) - will pay more attentinoto tuxedo
- - a man in a (txuedo:1.21) - alternative syntax
-- Loopback, run img2img procvessing multiple times
+ - a man in a ((tuxedo)) - will pay more attention to tuxedo
+ - a man in a (tuxedo:1.21) - alternative syntax
+- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
- have as many embeddings as you want and use any names you like for them
@@ -35,15 +35,15 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
- - get length of prompt in tokensas you type
- - get a warning after geenration if some text was truncated
+ - get length of prompt in tokens as you type
+ - get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page
-- Running arbitrary python code from UI (must run with commandline flag to enable)
+- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
@@ -113,6 +113,7 @@ The documentation was moved from this README over to the project's [wiki](https: - LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
diff --git a/javascript/hints.js b/javascript/hints.js index 84694eeb..e72e9338 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -15,7 +15,7 @@ titles = { "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.", - "\uD83D\uDCC2": "Open images output directory", + "\u{1f4c2}": "Open images output directory", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 21f25b38..1e297abb 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte onUiUpdate(function(){ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js new file mode 100644 index 00000000..8061be08 --- /dev/null +++ b/javascript/textualInversion.js @@ -0,0 +1,8 @@ +
+
+function start_training_textual_inversion(){
+ requestProgress('ti')
+ gradioApp().querySelector('#ti_error').innerHTML=''
+
+ return args_to_array(arguments)
+}
diff --git a/javascript/ui.js b/javascript/ui.js index bfe02410..b1053201 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -199,6 +199,26 @@ let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; +function update_txt2img_tokens(...args) { + update_token_counter("txt2img_token_button") + if (args.length == 2) + return args[0] + return args; +} + +function update_img2img_tokens(...args) { + update_token_counter("img2img_token_button") + if (args.length == 2) + return args[0] + return args; +} + +function update_token_counter(button_id) { + if (token_timeout) + clearTimeout(token_timeout); + token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); +} + function submit_prompt(event, generate_button_id) { if (event.altKey && event.keyCode === 13) { event.preventDefault(); @@ -207,8 +227,7 @@ function submit_prompt(event, generate_button_id) { } } -function update_token_counter(button_id) { - if (token_timeout) - clearTimeout(token_timeout); - token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); +function restart_reload(){ + document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; + setTimeout(function(){location.reload()},2000) } @@ -15,6 +15,7 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
+clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
@@ -111,6 +112,9 @@ if not skip_torch_cuda_test: if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
+if not is_installed("clip"):
+ run_pip(f"install {clip_package}", "clip")
+
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
diff --git a/modules/devices.py b/modules/devices.py index 08bb26d6..5d9c7a07 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index ea91abfe..4aed9283 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net): class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
- self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
- self.model_name = "ESRGAN 4x"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+ self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index fcd8544a..f1a564b7 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -97,11 +97,7 @@ def setup_model(dirname): return "GFPGAN"
def restore(self, np_image):
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- return np_image
+ return gfpgan_fix_faces(np_image)
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
diff --git a/modules/images.py b/modules/images.py index f1aed5d6..bba55158 100644 --- a/modules/images.py +++ b/modules/images.py @@ -311,7 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
+
+ #currently disabled if using the save button, will work otherwise
+ # if enabled it will cause a bug because styles is not included in the save_files data dictionary
+ if hasattr(p, "styles"):
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False))
+
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
@@ -374,7 +379,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
diff --git a/modules/img2img.py b/modules/img2img.py index 03e934e9..2ff8e261 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args): print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+ save_normally = output_dir == ''
+
p.do_not_save_grid = True
- p.do_not_save_samples = True
+ p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
@@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args): left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
- processed_image.save(os.path.join(output_dir, filename))
+ if not save_normally:
+ processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@@ -103,7 +106,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
)
- print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
+
+ if shared.cmd_opts.enable_console_prompts:
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur
diff --git a/modules/interrogate.py b/modules/interrogate.py index f62a4745..eed87144 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
diff --git a/modules/modelloader.py b/modules/modelloader.py index 8c862b42..b0f2f33d 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -5,7 +5,6 @@ import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url - from modules import shared from modules.upscaler import Upscaler from modules.paths import script_path, models_path @@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in glob.iglob(place + '**/**', recursive=True): - full_path = os.path.join(place, file) + full_path = file if os.path.isdir(full_path): continue if len(ext_filter) != 0: @@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): + sd = shared.script_path + # We can only do this 'magic' method to dynamically load upscalers if they are referenced, + # so we'll try to import any _model.py files before looking in __subclasses__ + modules_dir = os.path.join(sd, "modules") + for file in os.listdir(modules_dir): + if "_model.py" in file: + model_name = file.replace("_model.py", "") + full_model = f"modules.{model_name}_model" + try: + importlib.import_module(full_model) + except: + pass datas = [] + c_o = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): name = cls.__name__ module_name = cls.__module__ module = importlib.import_module(module_name) class_ = getattr(module, name) - cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" opt_string = None try: - opt_string = shared.opts.__getattr__(cmd_name) + if cmd_name in c_o: + opt_string = c_o[cmd_name] except: pass scaler = class_(opt_string) diff --git a/modules/paths.py b/modules/paths.py index ceb80417..606f7d66 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,6 @@ path_dirs = [ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..0a4b6198 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -79,7 +79,7 @@ class StableDiffusionProcessing: self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
+ self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
@@ -130,7 +130,7 @@ class Processed: self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
-
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
diff --git a/modules/scripts.py b/modules/scripts.py index 7c3bd5e7..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -162,6 +162,40 @@ class ScriptRunner: return processed
+ def reload_sources(self):
+ for si, script in list(enumerate(self.scripts)):
+ with open(script.filename, "r", encoding="utf8") as file:
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
+ text = file.read()
+
+ from types import ModuleType
+
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
+def reload_script_body_only():
+ scripts_txt2img.reload_sources()
+ scripts_img2img.reload_sources()
+
+
+def reload_scripts(basedir):
+ global scripts_txt2img, scripts_img2img
+
+ scripts_data.clear()
+ load_scripts(basedir)
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
diff --git a/modules/scunet_model.py b/modules/scunet_model.py new file mode 100644 index 00000000..7987ac14 --- /dev/null +++ b/modules/scunet_model.py @@ -0,0 +1,90 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.paths import models_path +from modules.scunet_model_arch import SCUNet as net + + +class UpscalerScuNET(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "ScuNET" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "ScuNET GAN" + self.model_name2 = "ScuNET PSNR" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" + self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pth"]) + scalers = [] + add_model2 = True + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + if name == self.model_name2 or file == self.model_url2: + add_model2 = False + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading ScuNET model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + if add_model2: + scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) + scalers.append(scaler_data2) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + + model = self.load_model(selected_file) + if model is None: + return img + + device = shared.device + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + + img = img.to(device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + device = shared.device + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: + print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) + return None + + model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + + return model + diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py new file mode 100644 index 00000000..972a2639 --- /dev/null +++ b/modules/scunet_model_arch.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, + 2).transpose( + 0, 1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting sqaure. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; |