From 29e74d6e71826da9a3fe3c5790fed1329fc4d1e8 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 16:26:16 +0200 Subject: Add support for Tensorboard for training embeddings --- modules/shared.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index faede821..2c6341f7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -254,6 +254,10 @@ options_templates.update(options_section(('training', "Training"), { "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."), + "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), + "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), + })) options_templates.update(options_section(('sd', "Stable Diffusion"), { -- cgit v1.2.3 From 75c4511e6b81ae8fb0dbd932043e8eb35cd09f72 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 29 Nov 2022 10:28:41 +0800 Subject: add AltDiffusion to webui Signed-off-by: zhaohu xing <920232796@qq.com> --- configs/altdiffusion/ad-inference.yaml | 72 ++ configs/stable-diffusion/v1-inference.yaml | 71 ++ ldm/data/__init__.py | 0 ldm/data/base.py | 23 + ldm/data/imagenet.py | 394 +++++++ ldm/data/lsun.py | 92 ++ ldm/lr_scheduler.py | 98 ++ ldm/models/autoencoder.py | 443 ++++++++ ldm/models/diffusion/__init__.py | 0 ldm/models/diffusion/classifier.py | 267 +++++ ldm/models/diffusion/ddim.py | 241 +++++ ldm/models/diffusion/ddpm.py | 1445 +++++++++++++++++++++++++ ldm/models/diffusion/dpm_solver/__init__.py | 1 + ldm/models/diffusion/dpm_solver/dpm_solver.py | 1184 ++++++++++++++++++++ ldm/models/diffusion/dpm_solver/sampler.py | 82 ++ ldm/models/diffusion/plms.py | 236 ++++ ldm/modules/attention.py | 261 +++++ ldm/modules/diffusionmodules/__init__.py | 0 ldm/modules/diffusionmodules/model.py | 835 ++++++++++++++ ldm/modules/diffusionmodules/openaimodel.py | 961 ++++++++++++++++ ldm/modules/diffusionmodules/util.py | 267 +++++ ldm/modules/distributions/__init__.py | 0 ldm/modules/distributions/distributions.py | 92 ++ ldm/modules/ema.py | 76 ++ ldm/modules/encoders/__init__.py | 0 ldm/modules/encoders/modules.py | 234 ++++ ldm/modules/encoders/xlmr.py | 137 +++ ldm/modules/image_degradation/__init__.py | 2 + ldm/modules/image_degradation/bsrgan.py | 730 +++++++++++++ ldm/modules/image_degradation/bsrgan_light.py | 650 +++++++++++ ldm/modules/image_degradation/utils/test.png | Bin 0 -> 441072 bytes ldm/modules/image_degradation/utils_image.py | 916 ++++++++++++++++ ldm/modules/losses/__init__.py | 1 + ldm/modules/losses/contperceptual.py | 111 ++ ldm/modules/losses/vqperceptual.py | 167 +++ ldm/modules/x_transformer.py | 641 +++++++++++ ldm/util.py | 203 ++++ modules/devices.py | 4 +- modules/sd_hijack.py | 23 +- modules/shared.py | 6 +- 40 files changed, 10957 insertions(+), 9 deletions(-) create mode 100644 configs/altdiffusion/ad-inference.yaml create mode 100644 configs/stable-diffusion/v1-inference.yaml create mode 100644 ldm/data/__init__.py create mode 100644 ldm/data/base.py create mode 100644 ldm/data/imagenet.py create mode 100644 ldm/data/lsun.py create mode 100644 ldm/lr_scheduler.py create mode 100644 ldm/models/autoencoder.py create mode 100644 ldm/models/diffusion/__init__.py create mode 100644 ldm/models/diffusion/classifier.py create mode 100644 ldm/models/diffusion/ddim.py create mode 100644 ldm/models/diffusion/ddpm.py create mode 100644 ldm/models/diffusion/dpm_solver/__init__.py create mode 100644 ldm/models/diffusion/dpm_solver/dpm_solver.py create mode 100644 ldm/models/diffusion/dpm_solver/sampler.py create mode 100644 ldm/models/diffusion/plms.py create mode 100644 ldm/modules/attention.py create mode 100644 ldm/modules/diffusionmodules/__init__.py create mode 100644 ldm/modules/diffusionmodules/model.py create mode 100644 ldm/modules/diffusionmodules/openaimodel.py create mode 100644 ldm/modules/diffusionmodules/util.py create mode 100644 ldm/modules/distributions/__init__.py create mode 100644 ldm/modules/distributions/distributions.py create mode 100644 ldm/modules/ema.py create mode 100644 ldm/modules/encoders/__init__.py create mode 100644 ldm/modules/encoders/modules.py create mode 100644 ldm/modules/encoders/xlmr.py create mode 100644 ldm/modules/image_degradation/__init__.py create mode 100644 ldm/modules/image_degradation/bsrgan.py create mode 100644 ldm/modules/image_degradation/bsrgan_light.py create mode 100644 ldm/modules/image_degradation/utils/test.png create mode 100644 ldm/modules/image_degradation/utils_image.py create mode 100644 ldm/modules/losses/__init__.py create mode 100644 ldm/modules/losses/contperceptual.py create mode 100644 ldm/modules/losses/vqperceptual.py create mode 100644 ldm/modules/x_transformer.py create mode 100644 ldm/util.py (limited to 'modules/shared.py') diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml new file mode 100644 index 00000000..1b11b63e --- /dev/null +++ b/configs/altdiffusion/ad-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.xlmr.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" \ No newline at end of file diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml new file mode 100644 index 00000000..2e6ef0f2 --- /dev/null +++ b/configs/stable-diffusion/v1-inference.yaml @@ -0,0 +1,71 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: altclip.model.AltCLIPEmbedder \ No newline at end of file diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/data/base.py b/ldm/data/base.py new file mode 100644 index 00000000..b196c2f7 --- /dev/null +++ b/ldm/data/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass \ No newline at end of file diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py new file mode 100644 index 00000000..1c473f9c --- /dev/null +++ b/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py new file mode 100644 index 00000000..6256e457 --- /dev/null +++ b/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 00000000..be39da9c --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 00000000..6a9c4f45 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 00000000..67e98b9d --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 00000000..fb31215d --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,241 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 00000000..bbedd04c --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1445 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 00000000..7427f38c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 00000000..bdb64e0c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1184 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 00000000..2c42d6f9 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,82 @@ +"""SAMPLING ONLY.""" + +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 00000000..78eeb100 --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 00000000..f4eff39c --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,261 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 00000000..533e589a --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 00000000..fcf95d1e --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,961 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 00000000..a952e6c4 --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 00000000..f2b8ef90 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 00000000..c8c75af4 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 00000000..ededbe43 --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/encoders/xlmr.py b/ldm/modules/encoders/xlmr.py new file mode 100644 index 00000000..beab3fdf --- /dev/null +++ b/ldm/modules/encoders/xlmr.py @@ -0,0 +1,137 @@ +from transformers import BertPreTrainedModel,BertModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 768 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + self.pooler = lambda x: x[:,0] + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + + # project every module + sequence_output_ln = self.pre_LN(sequence_output) + + # pooler + pooler_output = self.pooler(sequence_output_ln) + pooler_output = self.transformation(pooler_output) + projection_state = self.transformation(outputs.last_hidden_state) + + return { + 'pooler_output':pooler_output, + 'last_hidden_state':outputs.last_hidden_state, + 'hidden_states':outputs.hidden_states, + 'attentions':outputs.attentions, + 'projection_state':projection_state, + 'sequence_out': sequence_output + } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig \ No newline at end of file diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py new file mode 100644 index 00000000..7836cada --- /dev/null +++ b/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 00000000..32ef5616 --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 00000000..9e1f8239 --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 00000000..4249b43d Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 00000000..0175f155 --- /dev/null +++ b/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py new file mode 100644 index 00000000..876d7c5b --- /dev/null +++ b/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py new file mode 100644 index 00000000..672c1e32 --- /dev/null +++ b/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py new file mode 100644 index 00000000..f6998176 --- /dev/null +++ b/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 00000000..5fc15bf9 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 00000000..8ba38853 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..f30b6ebc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,8 +36,8 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps(): - return torch.device("mps") + # if has_mps(): + # return torch.device("mps") return cpu diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..26280fe4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -70,14 +70,19 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + + if shared.text_model_name == "XLMR-Large": + model_embeddings = m.cond_stage_model.roberta.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) + else : + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embeddings, self) - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model - apply_optimizations() + # apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - + try: + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + except: + self.comma_token = None + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): + if shared.text_model_name == "XLMR-Large": + return self.wrapped.encode(text) + use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state - + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) diff --git a/modules/shared.py b/modules/shared.py index c93ae2a3..9941d2f4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -21,7 +21,7 @@ from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -106,6 +106,10 @@ restricted_opts = { "outdir_txt2img_grids", "outdir_save", } +from omegaconf import OmegaConf +config = OmegaConf.load(f"{cmd_opts.config}") +# XLMR-Large +text_model_name = config.model.params.cond_stage_config.params.name cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -- cgit v1.2.3 From 9c86fb8cace6d8ac0843e0ddad0ba5ae7f3148c9 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Fri, 2 Dec 2022 16:08:46 +0800 Subject: fix bug Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/shared.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 1408dee3..ac7678c3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -111,7 +111,11 @@ restricted_opts = { from omegaconf import OmegaConf config = OmegaConf.load(f"{cmd_opts.config}") # XLMR-Large -text_model_name = config.model.params.cond_stage_config.params.name +try: + text_model_name = config.model.params.cond_stage_config.params.name + +except : + text_model_name = "stable_diffusion" cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -- cgit v1.2.3 From 965fc5ac5a6ccdf38342e21c97183011a04e799e Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 16:15:15 +0800 Subject: delete a file Signed-off-by: zhaohu xing <920232796@qq.com> --- .DS_Store | Bin 6148 -> 0 bytes modules/shared.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 .DS_Store (limited to 'modules/shared.py') diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 5008ddfc..00000000 Binary files a/.DS_Store and /dev/null differ diff --git a/modules/shared.py b/modules/shared.py index 522c56c1..8419b531 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -22,7 +22,7 @@ demo = None sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) -- cgit v1.2.3 From f23a822f1c9cb3bd2e8772c75af429e06515eaef Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 24 Dec 2022 20:45:16 +1100 Subject: feat(api): include job_timestamp in progress --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 8ea3b441..f356dbf7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -171,6 +171,7 @@ class State: "interrupted": self.skipped, "job": self.job, "job_count": self.job_count, + "job_timestamp": self.job_timestamp, "job_no": self.job_no, "sampling_step": self.sampling_step, "sampling_steps": self.sampling_steps, -- cgit v1.2.3 From 893933e05ad267778111b4fad6d1ecb80937afdf Mon Sep 17 00:00:00 2001 From: hitomi Date: Sun, 25 Dec 2022 20:49:25 +0800 Subject: Add memory cache for VAE weights --- modules/sd_vae.py | 31 +++++++++++++++++++++++++------ modules/shared.py | 1 + 2 files changed, 26 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 3856418e..ac71d62d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,6 @@ import torch import os +import collections from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path @@ -30,6 +31,7 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -149,13 +151,30 @@ def load_vae(model, vae_file=None): global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + if vae_file: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") - store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - _load_vae_dict(model, vae_dict_1) + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" + print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..671d30e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), -- cgit v1.2.3 From 03f486a2399df0a2b24c7aeea72e64f106a87297 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 26 Dec 2022 20:49:33 +0000 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..5edb316c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -418,6 +418,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), + 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) options_templates.update(options_section((None, "Hidden options"), { -- cgit v1.2.3 From 463048344fc036b262aa132584b65ee6e9fec6cf Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 30 Dec 2022 19:41:47 -0500 Subject: fix shared state dictionary --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..9a13fb60 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,7 +168,7 @@ class State: def dict(self): obj = { "skipped": self.skipped, - "interrupted": self.skipped, + "interrupted": self.interrupted, "job": self.job, "job_count": self.job_count, "job_no": self.job_no, -- cgit v1.2.3 From f34c7341720fb2059992926c9f9ae6ff25f7385b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 18:06:35 +0300 Subject: alt-diffusion integration --- configs/alt-diffusion-inference.yaml | 72 ++++++++++++++++++++++++++++++++++ configs/altdiffusion/ad-inference.yaml | 72 ---------------------------------- configs/v1-inference.yaml | 70 +++++++++++++++++++++++++++++++++ modules/sd_hijack.py | 18 +++++---- modules/sd_hijack_clip.py | 14 +++---- modules/sd_hijack_xlmr.py | 34 ++++++++++++++++ modules/shared.py | 10 +---- v1-inference.yaml | 70 --------------------------------- 8 files changed, 192 insertions(+), 168 deletions(-) create mode 100644 configs/alt-diffusion-inference.yaml delete mode 100644 configs/altdiffusion/ad-inference.yaml create mode 100644 configs/v1-inference.yaml create mode 100644 modules/sd_hijack_xlmr.py delete mode 100644 v1-inference.yaml (limited to 'modules/shared.py') diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml new file mode 100644 index 00000000..cfbee72d --- /dev/null +++ b/configs/alt-diffusion-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: modules.xlmr.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" \ No newline at end of file diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml deleted file mode 100644 index cfbee72d..00000000 --- a/configs/altdiffusion/ad-inference.yaml +++ /dev/null @@ -1,72 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: modules.xlmr.BertSeriesModelWithTransformation - params: - name: "XLMR-Large" \ No newline at end of file diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/configs/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bce23b03..edcbaf52 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -68,6 +68,7 @@ def fix_checkpoint(): ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward + class StableDiffusionModelHijack: fixes = None comments = [] @@ -79,21 +80,22 @@ class StableDiffusionModelHijack: def hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - + m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() - + + apply_optimizations() + self.clip = m.cond_stage_model fix_checkpoint() @@ -109,7 +111,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 9ea6e1ce..6ec50cca 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -4,7 +4,6 @@ import torch from modules import prompt_parser, devices from modules.shared import opts -import modules.shared as shared def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 @@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): - if shared.text_model_name == "XLMR-Large": - return self.wrapped.encode(text) - use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) self.tokenizer = wrapped.tokenizer - if shared.text_model_name == "XLMR-Large": - self.comma_token = None - else : - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) self.token_mults = {} - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py new file mode 100644 index 00000000..4ac51c38 --- /dev/null +++ b/modules/sd_hijack_xlmr.py @@ -0,0 +1,34 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + + +class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.id_start = wrapped.config.bos_token_id + self.id_end = wrapped.config.eos_token_id + self.id_pad = wrapped.config.pad_token_id + + self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma + + def encode_with_transformers(self, tokens): + # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a + # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer + # layer to work with - you have to use the last + + attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) + features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) + z = features['projection_state'] + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.roberta.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded diff --git a/modules/shared.py b/modules/shared.py index 2b31e717..715b9169 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,7 +23,7 @@ demo = None sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -108,14 +108,6 @@ restricted_opts = { "outdir_txt2img_grids", "outdir_save", } -from omegaconf import OmegaConf -config = OmegaConf.load(f"{cmd_opts.config}") -# XLMR-Large -try: - text_model_name = config.model.params.cond_stage_config.params.name - -except : - text_model_name = "stable_diffusion" cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access diff --git a/v1-inference.yaml b/v1-inference.yaml deleted file mode 100644 index d4effe56..00000000 --- a/v1-inference.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder -- cgit v1.2.3 From 29a3a7eb13478297bc7093971b48827ab8246f45 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 01:19:10 +0300 Subject: show sampler selection in dropdown, add option selection to revert to old radio group --- modules/shared.py | 1 + modules/ui.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 715b9169..948b9542 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -406,6 +406,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index 279b5110..c7b8ea5d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -643,6 +643,19 @@ Requested path was: {f} return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with gr.Row(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + else: + with gr.Group(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + def create_ui(): import modules.img2img import modules.txt2img @@ -660,9 +673,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - - with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): pass @@ -674,8 +684,7 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") with gr.Group(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) @@ -875,8 +884,7 @@ def create_ui(): with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") with gr.Group(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") -- cgit v1.2.3 From 16b9661d2741b241c3964fcbd56559c078b84822 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 09:51:37 +0300 Subject: change karras scheduler sigmas to values recommended by SD from old 0.1 to 10 with an option to revert to old --- modules/sd_samplers.py | 4 +++- modules/shared.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 177b5338..e904d860 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -465,7 +465,9 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) diff --git a/modules/shared.py b/modules/shared.py index 948b9542..7f430b93 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -368,13 +368,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), - "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) +options_templates.update(options_section(('compatibility', "Compatibility"), { + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), + "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), +})) + options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), -- cgit v1.2.3 From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 19:42:10 +0300 Subject: Hires fix rework --- modules/generation_parameters_copypaste.py | 32 ++++++++++++++ modules/images.py | 24 +++++++++-- modules/processing.py | 68 ++++++++++++------------------ modules/shared.py | 7 ++- modules/txt2img.py | 6 +-- modules/ui.py | 15 +++---- scripts/xy_grid.py | 4 +- 7 files changed, 96 insertions(+), 60 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8e7f0df0..d6fa822b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,6 @@ import base64 import io +import math import os import re from pathlib import Path @@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): return None +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + # old algorithm for auto-calculating first pass size + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + firstpass_width = math.ceil(scale * width / 64) * 64 + firstpass_height = math.ceil(scale * height / 64) * 64 + + hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires upscale'] = hr_scale + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + restore_old_hires_fix_params(res) + return res diff --git a/modules/images.py b/modules/images.py index f84fd485..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): return draw_grid_annotations(im, width, height, hor_texts, ver_texts) -def resize_image(resize_mode, im, width, height): +def resize_image(resize_mode, im, width, height, upscaler_name=None): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. + """ + + upscaler_name = upscaler_name or opts.upscaler_for_img2img + def resize(im, w, h): - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) if scale > 1.0: - upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] - assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" + upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] + assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" upscaler = upscalers[0] im = upscaler.scaler.upscale(im, scale, upscaler.data_path) diff --git a/modules/processing.py b/modules/processing.py index 42dc19ea..4654570c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength - self.firstphase_width = firstphase_width - self.firstphase_height = firstphase_height - self.truncate_x = 0 - self.truncate_y = 0 + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + + if firstphase_width != 0 or firstphase_height != 0: + print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) + self.hr_scale = self.width / firstphase_width + self.width = firstphase_width + self.height = firstphase_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - - if self.firstphase_width == 0 or self.firstphase_height == 0: - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - firstphase_width_truncated = int(scale * self.width) - firstphase_height_truncated = int(scale * self.height) - - else: - - width_ratio = self.width / self.firstphase_width - height_ratio = self.height / self.firstphase_height - - if width_ratio > height_ratio: - firstphase_width_truncated = self.firstphase_width - firstphase_height_truncated = self.firstphase_width * self.height / self.width - else: - firstphase_width_truncated = self.firstphase_height * self.width / self.height - firstphase_height_truncated = self.firstphase_height - - self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f - self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_upscaler is not None: + self.extra_generation_params["Hires upscaler"] = self.hr_upscaler def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + if self.enable_hr and latent_scale_mode is None: + assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) - - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + target_width = int(self.width * self.hr_scale) + target_height = int(self.height * self.hr_scale) - """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") - if opts.use_scale_latent_for_hires_fix: + if latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): save_intermediate(image, i) - image = images.resize_image(0, image, self.width, self.height) + image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) batch_images.append(image) @@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None diff --git a/modules/shared.py b/modules/shared.py index 7f430b93..b65559ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -545,6 +544,12 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": "bilinear", + "Latent (nearest)": "nearest", +} + sd_upscalers = [] sd_model = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 7f61e19a..e189a899 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: tiling=tiling, enable_hr=enable_hr, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 7070ea15..27cd9ddd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -684,11 +684,11 @@ def create_ui(): with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): @@ -729,8 +729,8 @@ def create_ui(): width, enable_hr, denoising_strength, - firstphase_width, - firstphase_height, + hr_scale, + hr_upscaler, ] + custom_inputs, outputs=[ @@ -762,7 +762,6 @@ def create_ui(): outputs=[hr_options], ) - txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -781,8 +780,8 @@ def create_ui(): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3e0b2805..f92f9776 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -202,7 +202,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None), @@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork self.model = shared.sd_model - self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): hypernetwork.apply_strength() opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From a1cf55a9d1c82f8e56c00d549bca5c8fa069f412 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 10:39:21 +0300 Subject: add option to reorder items in main UI --- modules/shared.py | 13 ++++++ modules/ui.py | 130 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 97 insertions(+), 46 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index b65559ee..23657a93 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -109,6 +109,17 @@ restricted_opts = { "outdir_save", } +ui_reorder_categories = [ + "sampler", + "dimensions", + "cfg", + "seed", + "checkboxes", + "hires_fix", + "batch", + "scripts", +] + cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ @@ -410,7 +421,9 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), + "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index 2c92c422..f2e7c0d6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -644,6 +644,13 @@ def create_sampler_and_steps_selection(choices, tabname): return steps, sampler_index +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + def create_ui(): import modules.img2img import modules.txt2img @@ -672,32 +679,48 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - with FormRow(visible=False) as hr_options: - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + + elif category == "hires_fix": + with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options: + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -865,28 +888,43 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) -- cgit v1.2.3 From bddebe09edeb6a18f2c06986d5658a7be3a563ea Mon Sep 17 00:00:00 2001 From: Shondoit Date: Tue, 3 Jan 2023 10:26:37 +0100 Subject: Save Optimizer next to TI embedding Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory) --- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 40 ++++++++++++++++++++------ 2 files changed, 33 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 23657a93..c541d18c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd253477..16176e90 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -28,6 +28,7 @@ class Embedding: self.cached_checksum = None self.sd_checkpoint = None self.sd_checkpoint_name = None + self.optimizer_state_dict = None def save(self, filename): embedding_data = { @@ -41,6 +42,13 @@ class Embedding: torch.save(embedding_data, filename) + if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: + optimizer_saved_dict = { + 'hash': self.checksum(), + 'optimizer_state_dict': self.optimizer_state_dict, + } + torch.save(optimizer_saved_dict, filename + '.optim') + def checksum(self): if self.cached_checksum is not None: return self.cached_checksum @@ -95,9 +103,10 @@ class EmbeddingDatabase: self.expected_shape = self.get_expected_shape() def process_file(path, filename): - name = os.path.splitext(filename)[0] + name, ext = os.path.splitext(filename) + ext = ext.upper() - if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) @@ -105,8 +114,10 @@ class EmbeddingDatabase: else: data = extract_image_data_embed(embed_image) name = data.get('name', name) - else: + elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") + else: + return # textual inversion embeddings if 'string_to_param' in data: @@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) + if shared.opts.save_optimizer_state: + optimizer_state_dict = None + if os.path.exists(filename + '.optim'): + optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') + if embedding.checksum() == optimizer_saved_dict.get('hash', None): + optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + + if optimizer_state_dict is not None: + optimizer.load_state_dict(optimizer_state_dict) + print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") + + scaler = torch.cuda.amp.GradScaler() batch_size = ds.batch_size @@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') - #if shared.opts.save_optimizer_state: - #embedding.optimizer_state_dict = optimizer.state_dict() - save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { @@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}

""" filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: print(traceback.format_exc(), file=sys.stderr) pass @@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename -def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): +def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None @@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache if remove_cached_checksum: embedding.cached_checksum = None embedding.name = embedding_name + embedding.optimizer_state_dict = optimizer.state_dict() embedding.save(filename) except: embedding.sd_checkpoint = old_sd_checkpoint -- cgit v1.2.3 From aaa4c2aacbb6523077334093c81bd475d757f7a1 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 09:45:16 -0500 Subject: add api logging --- modules/api/api.py | 24 +++++++++++++++++++++++- modules/shared.py | 1 + 2 files changed, 24 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index 9c670f00..53135470 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,12 @@ import base64 import io import time +import datetime import uvicorn from threading import Lock from io import BytesIO from gradio.processing_utils import decode_base64_to_file -from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -67,6 +68,26 @@ def encode_pil_to_base64(image): bytes_data = output_bytes.getvalue() return base64.b64encode(bytes_data) +def init_api_middleware(app: FastAPI): + @app.middleware("http") + async def log_and_time(req: Request, call_next): + ts = time.time() + res: Response = await call_next(req) + duration = str(round(time.time() - ts, 4)) + res.headers["X-Process-Time"] = duration + if shared.cmd_opts.api_log: + print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format( + t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code = res.status_code, + ver = req.scope.get('http_version', '0.0'), + cli = req.scope.get('client', ('0:0.0.0', 0))[0], + prot = req.scope.get('scheme', 'err'), + method = req.scope.get('method', 'err'), + p = req.scope.get('path', 'err'), + duration = duration, + )) + return res + class Api: def __init__(self, app: FastAPI, queue_lock: Lock): @@ -78,6 +99,7 @@ class Api: self.router = APIRouter() self.app = app + init_api_middleware(self.app) self.queue_lock = queue_lock self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) diff --git a/modules/shared.py b/modules/shared.py index 23657a93..2a03d716 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,6 +82,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 26 ++++++++++++++++++-------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/errors.py b/modules/errors.py index 372dc51a..a668c014 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. + """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(task, e) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a93..7588c47b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a97..13375e71 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 96cf15bedecbed97ef9b70b8413d543a9aee5adf Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 05:12:06 -0500 Subject: Add new latent upscale modes --- modules/shared.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 7588c47b..a10f69a9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -564,8 +564,11 @@ if os.path.exists(config_filename): latent_upscale_default_mode = "Latent" latent_upscale_modes = { - "Latent": "bilinear", - "Latent (nearest)": "nearest", + "Latent": {"mode": "bilinear", "antialias": False}, + "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, + "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, + "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True}, + "Latent (nearest)": {"mode": "nearest", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From b2151b934fe0a3613570c6abd7615d3788fd1c8f Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 05:36:18 -0500 Subject: Rename bicubic antialiased option Comma was causing the the value in PNG info to be quoted, which causes the upscaler dropdown option to be blank when sending to UI --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a10f69a9..c1b20081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -567,7 +567,7 @@ latent_upscale_modes = { "Latent": {"mode": "bilinear", "antialias": False}, "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, - "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True}, + "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, } -- cgit v1.2.3 From 1cfd8aec4ae5a6ca1afd67b44cb4ef6dd14d8c34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 16:05:42 +0300 Subject: make it possible to work with opts.show_progress_every_n_steps = -1 with medvram --- modules/shared.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 4fcc6edd..54a6ba23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -214,12 +214,13 @@ class State: """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + if not parallel_processing_allowed: + return + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: self.do_set_current_image() def do_set_current_image(self): - if not parallel_processing_allowed: - return if self.current_latent is None: return @@ -231,6 +232,7 @@ class State: self.current_image_sampling_step = self.sampling_step + state = State() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) -- cgit v1.2.3 From bc43293c640aef65df3136de9e5bd8b7e79eb3e0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 23:56:43 +0300 Subject: fix incorrect display/calculation for number of steps for hires fix in progress bars --- modules/processing.py | 9 ++++++--- modules/sd_samplers.py | 5 +++-- modules/shared.py | 4 +++- 3 files changed, 12 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 9cad05f2..f28e7212 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -685,10 +685,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if state.job_count == -1: - state.job_count = self.n_iter * 2 - else: + if not state.processing_has_refined_job_count: + if state.job_count == -1: + state.job_count = self.n_iter + + shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) state.job_count = state.job_count * 2 + state.processing_has_refined_job_count = True if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index e904d860..3851a77f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -97,8 +97,9 @@ sampler_extra_params = { def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: - steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = p.steps - 1 + requested_steps = (steps or p.steps) + steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 + t_enc = requested_steps - 1 else: steps = p.steps t_enc = int(min(p.denoising_strength, 0.999) * steps) diff --git a/modules/shared.py b/modules/shared.py index 54a6ba23..04c545ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -153,6 +153,7 @@ class State: job = "" job_no = 0 job_count = 0 + processing_has_refined_job_count = False job_timestamp = '0' sampling_step = 0 sampling_steps = 0 @@ -194,6 +195,7 @@ class State: def begin(self): self.sampling_step = 0 self.job_count = -1 + self.processing_has_refined_job_count = False self.job_no = 0 self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") self.current_latent = None @@ -608,7 +610,7 @@ class TotalTQDM: return if self._tqdm is None: self.reset() - self._tqdm.total=new_total + self._tqdm.total = new_total def clear(self): if self._tqdm is not None: -- cgit v1.2.3 From eea8fc40e16664ddc8a9aec77206da704a35dde0 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 07:24:22 -0800 Subject: Add option to save ti settings to file. --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 30 +++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..933cd738 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,6 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), + "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..2bed2ecb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,7 @@ import os import sys import traceback +import inspect import torch import tqdm @@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args + + # Create a list of the argument names to include in the settings string. + names = arg_names[:16] # Include all arguments up until the preview-related ones. + if preview_from_txt2img: + names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") + def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" @@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + if shared.opts.save_train_settings_to_txt: + save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From 19a81ac2871ec900fc8b7955bbc2554b6c5ac6b1 Mon Sep 17 00:00:00 2001 From: cat Date: Thu, 5 Jan 2023 20:17:39 +0500 Subject: hires-fix: add "nearest-exact" latent upscale mode. --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..b7a3ce5c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -576,6 +576,7 @@ latent_upscale_modes = { "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, + "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From b85c2b5cf4a6809bc871718cf4680d49c3e95e94 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 08:14:38 -0800 Subject: Clean up ti, add same behavior to hypernetwork. --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++++++- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 14 +++++++----- 3 files changed, 40 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6a9b1398..d5985263 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -401,7 +401,33 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() +# Note: textual_inversion.py has a nearly identical function of the same name. +def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + # Starting index of preview-related arguments. + border_index = 19 + + # Get a list of the argument names, excluding default argument. + sig = inspect.signature(save_settings_to_file) + arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] + + # Create a list of the argument names to include in the settings string. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. + + # Include preview-related arguments if applicable. + if preview_from_txt2img: + names.extend(arg_names[border_index:]) + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -457,7 +483,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) diff --git a/modules/shared.py b/modules/shared.py index 933cd738..10231a75 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), + "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2bed2ecb..68648550 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -230,18 +230,20 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +# Note: hypernetwork.py has a nearly identical function of the same name. def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): checkpoint = sd_models.select_checkpoint() model_name = checkpoint.model_name model_hash = '[{}]'.format(checkpoint.hash) - + # Starting index of preview-related arguments. + border_index = 16 # Get a list of the argument names. arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. - names = arg_names[:16] # Include all arguments up until the preview-related ones. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" @@ -329,8 +331,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - if shared.opts.save_train_settings_to_txt: - save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- README.md | 1 + modules/sd_hijack.py | 15 ++- modules/sd_hijack_optimizations.py | 124 ++++++++++++++++++----- modules/shared.py | 4 + modules/sub_quadratic_attention.py | 201 +++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- 6 files changed, 312 insertions(+), 35 deletions(-) create mode 100644 modules/sub_quadratic_attention.py (limited to 'modules/shared.py') diff --git a/README.md b/README.md index 556000fb..1913caf3 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https: - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 690a9ec2..019a6f3f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet -from modules.sd_hijack_optimizations import invokeAI_mps_available - import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel @@ -40,17 +38,16 @@ def apply_optimizations(): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + elif cmd_opts.opt_sub_quad_attention: + print("Applying sub-quadratic cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - else: - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..487a7792 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) +parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 00000000..b11dc1c7 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,201 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, Protocol, List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] + return x[slicing] + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + +class SummarizeChunk(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... + +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = dynamic_slice( + key, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, v_channels_per_head) + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). + use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( + query, + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res diff --git a/requirements.txt b/requirements.txt index 5bed694e..0dbea322 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,4 @@ inflection GitPython torchsde safetensors -psutil; sys_platform == 'darwin' +psutil -- cgit v1.2.3 From 683287d87f6401083a8d63eedc00ca7410214ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 08:52:06 +0300 Subject: rework saving training params to file #6372 --- modules/hypernetworks/hypernetwork.py | 28 +++++++------------------- modules/shared.py | 2 +- modules/textual_inversion/logging.py | 24 ++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 23 +++------------------ 4 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 modules/textual_inversion/logging.py (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3237c37a..b0cfbe71 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -13,7 +13,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers -from modules.textual_inversion import textual_inversion +from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ @@ -401,25 +401,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() -# Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 21 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -477,7 +459,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + saved_params = dict( + model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} + ) + logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/shared.py b/modules/shared.py index f0e10b35..57e489d0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 00000000..8b1981d5 --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,24 @@ +import datetime +import json +import os + +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} +saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e9cf432f..f9f5e8cd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay) +from modules.textual_inversion.logging import save_settings_to_file + class Embedding: def __init__(self, vec, name, step=None): @@ -231,25 +233,6 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) -# Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 18 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" @@ -330,7 +313,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From b95a4c0ce5ab9c414e0494193bfff665f45e9e65 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:01:51 -0500 Subject: Change sub-quad chunk threshold to use percentage --- modules/sd_hijack_optimizations.py | 18 +++++++++--------- modules/shared.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f5c153e8..b416e9ac 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) @@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape _, k_tokens, _ = k.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) - - if chunk_threshold_bytes is None: - chunk_threshold_bytes = available_vram - elif chunk_threshold_bytes == 0: + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - if kv_chunk_size_min is None: + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) elif kv_chunk_size_min == 0: kv_chunk_size_min = None @@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index cb1dc312..d7a81db1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) -parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -- cgit v1.2.3 From 50194de93ffc9db763d9b08fcc9c3bde1aa86151 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:12:45 +0100 Subject: typo UI fixes #6391 --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 57e489d0..865c3c07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,7 +430,7 @@ options_templates.update(options_section(('ui', "User interface"), { "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) -- cgit v1.2.3 From d4fd2418efb0986a8226add0b800fb5c73ffb58c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 14:57:47 +0300 Subject: add an option to use old hiresfix width/height behavior add a visual effect to inactive hires fix elements --- javascript/hires_fix.js | 25 +++++++++++++++++++++++++ modules/generation_parameters_copypaste.py | 17 +++++++++++------ modules/processing.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + modules/ui.py | 23 ++++++++++++++--------- style.css | 4 ++++ 6 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 javascript/hires_fix.js (limited to 'modules/shared.py') diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js new file mode 100644 index 00000000..07fba549 --- /dev/null +++ b/javascript/hires_fix.js @@ -0,0 +1,25 @@ + +function setInactive(elem, inactive){ + console.log(elem) + if(inactive){ + elem.classList.add('inactive') + } else{ + elem.classList.remove('inactive') + } +} + +function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ + console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) + + hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') + hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') + hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') + + gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" + + setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) + setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) + setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) + + return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] +} diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..f7f68b67 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res): firstpass_width = res.get('First pass size-1', None) firstpass_height = res.get('First pass size-2', None) + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", None)) + hires_height = int(res.get("Hires resize-2", None)) + + if hires_width is not None and hires_height is not None: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + if firstpass_width is None or firstpass_height is None: return @@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - # old algorithm for auto-calculating first pass size - desired_pixel_count = 512 * 512 - actual_pixel_count = width * height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - firstpass_width = math.ceil(scale * width / 64) * 64 - firstpass_height = math.ceil(scale * height / 64) * 64 + from modules import processing + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..f04a0e1e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: return res +def old_hires_fix_first_pass_dimensions(width, height): + """old algorithm for auto-calculating first pass size""" + + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + width = math.ceil(scale * width / 64) * 64 + height = math.ceil(scale * height / 64) * 64 + + return width, height + + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None @@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: - print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) - self.hr_scale = self.width / firstphase_width + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height self.width = firstphase_width self.height = firstphase_height self.truncate_x = 0 self.truncate_y = 0 + self.applied_old_hires_behavior_to = None def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): + self.hr_resize_x = self.width + self.hr_resize_y = self.height + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + + self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height) + self.applied_old_hires_behavior_to = (self.width, self.height) + if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) diff --git a/modules/shared.py b/modules/shared.py index a6712dae..a1e10201 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 99483130..719c26b3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): @@ -745,15 +745,20 @@ def create_ui(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - hr_resolution_preview_args = dict( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False - ) - for input in hr_resolution_preview_inputs: - input.change(**hr_resolution_preview_args) + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) diff --git a/style.css b/style.css index d796cbe9..ec5e4182 100644 --- a/style.css +++ b/style.css @@ -670,6 +670,10 @@ footer { min-width: auto; } +.inactive{ + opacity: 0.5; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 23:35:40 +0300 Subject: make a dropdown for prompt template selection --- modules/hypernetworks/hypernetwork.py | 7 ++++-- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------ modules/ui.py | 11 ++++++-- webui.py | 3 +++ 5 files changed, 45 insertions(+), 12 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = textual_inversion.textual_inversion_templates.get(template_filename, None) + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys import traceback import inspect +from collections import namedtuple import torch import tqdm @@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler -from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, - insert_image_data_embed, extract_image_data_embed, - caption_image_overlay) +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" @@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert template_file, "Prompt template file is empty" - assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" @@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui +from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text @@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row], ) + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with FormRow(): @@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") diff --git a/webui.py b/webui.py index 8737e593..47d372c7 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae import modules.txt2img import modules.script_callbacks +import modules.textual_inversion.textual_inversion import modules.ui from modules import modelloader @@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list() + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + try: modules.sd_models.load_model() except Exception as e: -- cgit v1.2.3 From 29fb5327640465fc83111e2170c5d8aa2b15266c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 23:47:02 +0300 Subject: change color selector in settings to be part of form --- modules/shared.py | 4 ++-- modules/ui_components.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index aa37c8ce..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components from modules.paths import models_path, script_path, sd_path @@ -387,7 +387,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), diff --git a/modules/ui_components.py b/modules/ui_components.py index cac001dc..97acff06 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -31,3 +31,9 @@ class FormHTML(gr.HTML, gr.components.FormComponent): def get_block_name(self): return "html" + +class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): + """Same as gr.ColorPicker but fits inside gradio forms""" + + def get_block_name(self): + return "colorpicker" -- cgit v1.2.3 From 0b8911d883118daa54f7735c5b753b5575d9f943 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 20:33:24 +0300 Subject: img2img UI rework: obsolete --gradio-img2img-tool --gradio-inpaint-tool and always show all tools each in own tab --- modules/img2img.py | 58 ++++++++++++++---------------- modules/shared.py | 4 +-- modules/ui.py | 103 +++++++++++++++++++++++++++-------------------------- style.css | 4 ++- 4 files changed, 84 insertions(+), 85 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/img2img.py b/modules/img2img.py index ca58b5d8..f62783c6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): - is_inpaint = mode == 1 - is_batch = mode == 2 - - if is_inpaint: - # Drawn mask - if mask_mode == 0: - is_mask_sketch = isinstance(init_img_with_mask, dict) - is_mask_paint = not is_mask_sketch - if is_mask_sketch: - # Sketch: mask iff. not transparent - image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - else: - # Color-sketch: mask iff. painted over - image = init_img_with_mask - orig = init_img_with_mask_orig or init_img_with_mask - pred = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") - mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) - blur = ImageFilter.GaussianBlur(mask_blur) - image = Image.composite(image.filter(blur), orig, mask.filter(blur)) - - image = image.convert("RGB") - # Uploaded mask - else: - image = init_img_inpaint - mask = init_mask_inpaint - # No mask +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): + is_batch = mode == 5 + + if mode == 0: # img2img + image = init_img.convert("RGB") + mask = None + elif mode == 1: # img2img sketch + image = sketch.convert("RGB") + mask = None + elif mode == 2: # inpaint + image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] + alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') + mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') + image = image.convert("RGB") + elif mode == 3: # inpaint sketch + image = inpaint_color_sketch + orig = inpaint_color_sketch_orig or inpaint_color_sketch + pred = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") + mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) + blur = ImageFilter.GaussianBlur(mask_blur) + image = Image.composite(image.filter(blur), orig, mask.filter(blur)) + image = image.convert("RGB") + elif mode == 4: # inpaint upload mask + image = init_img_inpaint + mask = init_mask_inpaint else: - image = init_img + image = None mask = None # Use the EXIF orientation of photos taken by smartphones. diff --git a/modules/shared.py b/modules/shared.py index 264264a6..1c964237 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") -parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") +parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') +parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 82f5dd7c..e86a624b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,53 +795,67 @@ def create_ui(): with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): + with gr.Tabs(elem_id="mode_img2img"): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + inpaint_color_sketch_orig = gr.State(None) - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -900,20 +914,6 @@ def create_ui(): ] ) - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", @@ -924,11 +924,12 @@ def create_ui(): img2img_prompt_style, img2img_prompt_style2, init_img, + sketch, init_img_with_mask, - init_img_with_mask_orig, + inpaint_color_sketch, + inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, - mask_mode, steps, sampler_index, mask_blur, diff --git a/style.css b/style.css index ec5e4182..ffd6307f 100644 --- a/style.css +++ b/style.css @@ -557,7 +557,9 @@ canvas[key="mask"] { } #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, -img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img +#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, +#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, +#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img { height: 480px !important; max-height: 480px !important; -- cgit v1.2.3 From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 09:56:59 +0300 Subject: change hash to sha256 --- .gitignore | 1 + modules/api/api.py | 2 +- modules/api/models.py | 3 +- modules/hashes.py | 72 +++++++++++++++ modules/hypernetworks/hypernetwork.py | 4 +- modules/sd_models.py | 116 ++++++++++++++++--------- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- webui.py | 2 + 9 files changed, 158 insertions(+), 50 deletions(-) create mode 100644 modules/hashes.py (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index 21fa26a7..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt +/cache.json diff --git a/modules/api/api.py b/modules/api/api.py index 5767ba90..9814bbc2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -371,7 +371,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index c78095ca..1eb1fcf1 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -224,7 +224,8 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") - hash: str = Field(title="Hash") + hash: Optional[str] = Field(title="Short hash") + sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") config: str = Field(title="Config file") diff --git a/modules/hashes.py b/modules/hashes.py new file mode 100644 index 00000000..ebfbd90c --- /dev/null +++ b/modules/hashes.py @@ -0,0 +1,72 @@ +import hashlib +import json +import os.path + +import filelock + + +cache_filename = "cache.json" +cache_data = None + + +def dump_cache(): + with filelock.FileLock(cache_filename+".lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + global cache_data + + if cache_data is None: + with filelock.FileLock(cache_filename+".lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def calculate_sha256(filename): + hash_sha256 = hashlib.sha256() + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def sha256(filename, title): + hashes = cache("hashes") + ondisk_mtime = os.path.getmtime(filename) + + if title in hashes: + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime <= cached_mtime and cached_sha256 is not None: + return cached_sha256 + + print(f"Calculating sha256 for {filename}: ", end='') + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + hashes[title] = { + "mtime": ondisk_mtime, + "sha256": sha256_value, + } + + dump_cache() + + return sha256_value + + + + + diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 83cbb4f0..9b5f2e79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.opts.save_training_settings_to_txt: saved_params = dict( - model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} ) logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) @@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None try: - hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint = checkpoint.shorthash hypernetwork.sd_checkpoint_name = checkpoint.model_name hypernetwork.name = hypernetwork_name hypernetwork.save(filename) diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..7babb9ae 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,56 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} +checkpoint_alisases = {} checkpoints_loaded = collections.OrderedDict() + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.title = name + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + self.shorthash = None + self.sha256 = None + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, self.title) + self.shorthash = self.sha256[0:10] + + if self.shorthash not in self.ids: + self.ids += [self.shorthash, self.sha256] + self.register() + + return self.shorthash + + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -43,10 +82,14 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - convert = lambda name: int(name) if name.isdigit() else name.lower() - alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) def find_checkpoint_config(info): @@ -62,48 +105,38 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() + checkpoint_alisases.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) - def modeltitle(path, shorthash): - abspath = os.path.abspath(path) - - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') - elif abspath.startswith(model_path): - name = abspath.replace(model_path, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - - return f'{name} [{shorthash}]', shortname - cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): - h = model_hash(cmd_ckpt) - title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.data['sd_model_checkpoint'] = title + checkpoint_info = CheckpointInfo(cmd_ckpt) + checkpoint_info.register() + + shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + for filename in model_list: - h = model_hash(filename) - title, short_model_name = modeltitle(filename, h) + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() + - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] -def get_closet_checkpoint_match(searchString): - applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable) > 0: - return applicable[0] return None def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + try: with open(filename, "rb") as file: import hashlib @@ -119,7 +152,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoints_list.get(model_checkpoint, None) + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info, vae_file="auto"): - checkpoint_file = checkpoint_info.filename - sd_model_hash = checkpoint_info.hash +def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): + sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") - sd = read_state_dict(checkpoint_file) + sd = read_state_dict(checkpoint_info.filename) model.load_state_dict(sd, strict=False) del sd @@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_file + model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/shared.py b/modules/shared.py index b90ded52..d74c069d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6939efcc..63935878 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method @@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) + footer_mid = '[{}]'.format(checkpoint.shorthash) footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) @@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None try: - embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint = checkpoint.shorthash embedding.sd_checkpoint_name = checkpoint.model_name if remove_cached_checksum: embedding.cached_checksum = None diff --git a/webui.py b/webui.py index 47d372c7..1fff80da 100644 --- a/webui.py +++ b/webui.py @@ -78,6 +78,8 @@ def initialize(): print("Stable diffusion model failed to load, exiting", file=sys.stderr) exit(1) + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From f9ac3352cb66ce2bc0aa4325130fc7267fb35e4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 10:25:21 +0300 Subject: change hypernets to use sha256 hashes --- modules/hypernetworks/hypernetwork.py | 40 ++++++++++++++++++++--------------- modules/processing.py | 2 +- modules/sd_models.py | 2 +- modules/shared.py | 1 + 4 files changed, 26 insertions(+), 19 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9b5f2e79..3aebefa8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers +from modules import devices, processing, sd_models, shared, sd_samplers, hashes from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -225,7 +225,7 @@ class Hypernetwork: torch.save(state_dict, filename) if shared.opts.save_optimizer_state and self.optimizer_state_dict: - optimizer_saved_dict['hash'] = sd_models.model_hash(filename) + optimizer_saved_dict['hash'] = self.shorthash() optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(optimizer_saved_dict, filename + '.optim') @@ -237,32 +237,33 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - print(self.layer_structure) - optional_info = state_dict.get('optional_info', None) - if optional_info is not None: - print(f"INFO:\n {optional_info}\n") - self.optional_info = optional_info + self.optional_info = state_dict.get('optional_info', None) self.activation_func = state_dict.get('activation_func', None) - print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') - print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) - print(f"Layer norm is set to {self.add_layer_norm}") self.dropout_structure = state_dict.get('dropout_structure', None) self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) - print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) - print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. if self.dropout_structure is None: - print("Using previous dropout structure") self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) - print(f"Dropout structure is set to {self.dropout_structure}") - optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + if shared.opts.print_hypernet_extra: + if self.optional_info is not None: + print(f" INFO:\n {self.optional_info}\n") - if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + print(f" Layer structure: {self.layer_structure}") + print(f" Activation function: {self.activation_func}") + print(f" Weight initialization: {self.weight_init}") + print(f" Layer norm: {self.add_layer_norm}") + print(f" Dropout usage: {self.use_dropout}" ) + print(f" Activate last layer: {self.activate_output}") + print(f" Dropout structure: {self.dropout_structure}") + + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {} + + if self.shorthash() == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None @@ -289,6 +290,11 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) self.eval() + def shorthash(self): + sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}') + + return sha256[0:10] + def list_hypernetworks(path): res = {} @@ -296,7 +302,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name + f"({sd_models.model_hash(filename)})"] = filename + res[name] = filename return res diff --git a/modules/processing.py b/modules/processing.py index ae04cab7..849f6b19 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)), + "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), diff --git a/modules/sd_models.py b/modules/sd_models.py index 7babb9ae..8f00191c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -125,7 +125,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_alisases.get(search_string, None) if checkpoint_info is not None: - return + return checkpoint_info found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) if found: diff --git a/modules/shared.py b/modules/shared.py index d74c069d..a6c61db3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), })) options_templates.update(options_section(('training', "Training"), { -- cgit v1.2.3 From 08c6f009a5ee92dd3218a942c08e8337c26352be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 15:55:40 +0300 Subject: load hashes from cache for checkpoints that have them add checkpoint hash to footer --- javascript/ui.js | 25 ++++++++++++++++--------- modules/hashes.py | 26 +++++++++++++++++++------- modules/sd_models.py | 9 ++++++--- modules/shared.py | 1 + modules/ui.py | 2 ++ script.js | 4 ++++ 6 files changed, 48 insertions(+), 19 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index a41dd26f..1e04a8f4 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { opts = {} -function apply_settings(jsdata){ - console.log(jsdata) - - opts = JSON.parse(jsdata) - - return jsdata -} - onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -160,7 +152,7 @@ onUiUpdate(function(){ textarea = json_elem.querySelector('textarea') jsdata = textarea.value opts = JSON.parse(jsdata) - + executeCallbacks(optionsChangedCallbacks); Object.defineProperty(textarea, 'value', { set: function(newValue) { @@ -171,6 +163,8 @@ onUiUpdate(function(){ if (oldValue != newValue) { opts = JSON.parse(textarea.value) } + + executeCallbacks(optionsChangedCallbacks); }, get: function() { var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value'); @@ -201,6 +195,19 @@ onUiUpdate(function(){ } }) + +onOptionsChanged(function(){ + elem = gradioApp().getElementById('sd_checkpoint_hash') + sd_checkpoint_hash = opts.sd_checkpoint_hash || "" + shorthash = sd_checkpoint_hash.substr(0,10) + + if(elem && elem.textContent != shorthash){ + elem.textContent = shorthash + elem.title = sd_checkpoint_hash + elem.href = "https://google.com/search?q=" + sd_checkpoint_hash + } +}) + let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; diff --git a/modules/hashes.py b/modules/hashes.py index ebfbd90c..14231771 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -42,23 +42,35 @@ def calculate_sha256(filename): return hash_sha256.hexdigest() -def sha256(filename, title): +def sha256_from_cache(filename, title): hashes = cache("hashes") ondisk_mtime = os.path.getmtime(filename) - if title in hashes: - cached_sha256 = hashes[title].get("sha256", None) - cached_mtime = hashes[title].get("mtime", 0) + if title not in hashes: + return None + + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime > cached_mtime or cached_sha256 is None: + return None + + return cached_sha256 + + +def sha256(filename, title): + hashes = cache("hashes") - if ondisk_mtime <= cached_mtime and cached_sha256 is not None: - return cached_sha256 + sha256_value = sha256_from_cache(filename, title) + if sha256_value is not None: + return sha256_value print(f"Calculating sha256 for {filename}: ", end='') sha256_value = calculate_sha256(filename) print(f"{sha256_value}") hashes[title] = { - "mtime": ondisk_mtime, + "mtime": os.path.getmtime(filename), "sha256": sha256_value, } diff --git a/modules/sd_models.py b/modules/sd_models.py index 1fe6d11b..e5a0bc63 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -44,9 +44,11 @@ class CheckpointInfo: self.title = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] - self.shorthash = None - self.sha256 = None + + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.shorthash = self.sha256[0:10] if self.sha256 else None + + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -269,6 +271,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 model.logvar = model.logvar.to(devices.device) # fix for training diff --git a/modules/shared.py b/modules/shared.py index a6c61db3..c9988d4d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -458,6 +458,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), + "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) options_templates.update() diff --git a/modules/ui.py b/modules/ui.py index e86a624b..2625ae32 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1841,4 +1841,6 @@ xformers: {xformers_version} gradio: {gr.__version__}  •  commit: {short_commit} + •  +checkpoint: N/A """ diff --git a/script.js b/script.js index 21960d91..3345e32b 100644 --- a/script.js +++ b/script.js @@ -14,6 +14,7 @@ function get_uiCurrentTabContent() { uiUpdateCallbacks = [] uiTabChangeCallbacks = [] +optionsChangedCallbacks = [] let uiCurrentTab = null function onUiUpdate(callback){ @@ -22,6 +23,9 @@ function onUiUpdate(callback){ function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } +function onOptionsChanged(callback){ + optionsChangedCallbacks.push(callback) +} function runCallback(x, m){ try { -- cgit v1.2.3 From f94a215abed85b34ae978853078812801d3e7738 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 16:29:23 +0300 Subject: add an option to choose what you want to see in live preview (Live preview subject) and moves live preview settings to its own tab --- modules/sd_samplers.py | 15 ++++++++++----- modules/shared.py | 13 +++++++++---- modules/ui_progress.py | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 01221b89..7616fded 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None): def store_latent(decoded): state.current_latent = decoded - if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: + if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: shared.state.current_image = sample_to_image(decoded) @@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) if valid_step == floor(valid_step): return int(valid_step) + 1 @@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler: if image_conditioning is not None: conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - + samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples @@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + if opts.live_preview_content == "Prompt": + store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + store_latent(x_out[-uncond.shape[0]:]) + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) if self.mask is not None: @@ -423,7 +427,8 @@ class KDiffusionSampler: def callback_state(self, d): step = d['i'] latent = d["denoised"] - store_latent(latent) + if opts.live_preview_content == "Combined": + store_latent(latent) self.last_latent = latent if self.stop_at is not None and step > self.stop_at: diff --git a/modules/shared.py b/modules/shared.py index c9988d4d..e0ec3136 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -176,7 +176,7 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: self.do_set_current_image() self.job_no += 1 @@ -224,7 +224,7 @@ class State: if not parallel_processing_allowed: return - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable: self.do_set_current_image() def do_set_current_image(self): @@ -423,8 +423,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), @@ -444,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), { 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) +options_templates.update(options_section(('ui', "Live previews"), { + "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), + "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), +})) + options_templates.update(options_section(('sampler-params', "Sampler parameters"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 592fda55..7cd312e4 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -52,7 +52,7 @@ def check_progress_call(id_part): image = gr.update(visible=False) preview_visibility = gr.update(visible=False) - if opts.show_progress_every_n_steps != 0: + if opts.live_previews_enable: shared.state.set_current_image() image = shared.state.current_image -- cgit v1.2.3 From fad850fc3d33e7cda2ce4b3a32ab7976c313db53 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 14 Jan 2023 11:18:05 -0500 Subject: add server_start to shared.state --- modules/shared.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e0ec3136..ef93637c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,6 +168,7 @@ class State: textinfo = None time_start = None need_restart = False + server_start = None def skip(self): self.skipped = True @@ -241,6 +242,7 @@ class State: state = State() +state.server_start = time.time() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) -- cgit v1.2.3 From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: fix bug with "Ignore selected VAE for..." option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_models.py | 6 +- modules/sd_vae.py | 194 ++++++++++++++++++++------------------------------- modules/shared.py | 4 +- scripts/xy_grid.py | 27 ++++--- 4 files changed, 95 insertions(+), 136 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): +def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) - sd_vae.load_vae(model, vae_file) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) def enable_midas_autodownload(): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + for filepath in candidates: name = get_filename(filepath) - res[name] = filepath - vae_list.clear() - vae_list.extend(default_vae_list) - vae_list.extend(list(res.keys())) - vae_dict.clear() - vae_dict.update(res) - vae_dict.update(default_vae_dict) - return vae_list - - -def get_vae_from_settings(vae_file="auto"): - # else, we load from settings, if not set to be default - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print(f"Selected VAE doesn't exist: {vae_file}") - return vae_file - - -def resolve_vae(checkpoint_file=None, vae_file="auto"): - global first_load, vae_dict, vae_list - - # if vae_file argument is provided, it takes priority, but not saved - if vae_file and vae_file not in default_vae_list: - if not os.path.isfile(vae_file): - print(f"VAE provided as function argument doesn't exist: {vae_file}") - vae_file = "auto" - # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported - if first_load and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - shared.opts.data['sd_vae'] = get_filename(vae_file) - else: - print(f"VAE provided as command line argument doesn't exist: {vae_file}") - # fallback to selector in settings, if vae selector not set to act as default fallback - if not shared.opts.sd_vae_as_default: - vae_file = get_vae_from_settings(vae_file) - # vae-path cmd arg takes priority for auto - if vae_file == "auto" and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - print(f"Using VAE provided as command line argument: {vae_file}") - # if still not found, try look for ".vae.pt" beside model - model_path = os.path.splitext(checkpoint_file)[0] - if vae_file == "auto": - vae_file_try = model_path + ".vae.pt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.ckpt" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.ckpt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.safetensors" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.safetensors" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # No more fallbacks for auto - if vae_file == "auto": - vae_file = None - # Last check, just because - if vae_file and not os.path.exists(vae_file): - vae_file = None - - return vae_file - - -def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if shared.opts.sd_vae != "Automatic": + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -165,12 +120,12 @@ def load_vae(model, vae_file=None): if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache - print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) _load_vae_dict(model, checkpoints_loaded[vae_file]) else: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) @@ -191,14 +146,12 @@ def load_vae(model, vae_file=None): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file - vae_list.append(vae_opt) + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file - first_load = False - # don't call this from outside def _load_vae_dict(model, vae_dict_1): @@ -211,7 +164,10 @@ def clear_loaded_vae(): loaded_vae_file = None -def reload_vae_weights(sd_model=None, vae_file="auto"): +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack if not sd_model: @@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_info = sd_model.sd_checkpoint_info checkpoint_file = checkpoint_info.filename - vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" if loaded_vae_file == vae_file: return @@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file) + load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) @@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("VAE Weights loaded.") + print("VAE weights loaded.") return sd_model diff --git a/modules/shared.py b/modules/shared.py index e0ec3136..9756adea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) -parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) +parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) @@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f04d9b7e..bd3087d4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs): def find_vae(name: str): - if name.lower() in ['auto', 'none']: - return name + if name.lower() in ['auto', 'automatic']: + return modules.sd_vae.unspecified + if name.lower() == 'none': + return None else: - vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE')) - found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True) - if found: - return found[0] + choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + print(f"No VAE found for {name}; using automatic") + return modules.sd_vae.unspecified else: - return 'auto' + return modules.sd_vae.vae_dict[choices[0]] def apply_vae(p, x, xs): - if x.lower().strip() == 'none': - modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None') - else: - found = find_vae(x) - if found: - v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found) + modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): @@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object): def __exit__(self, exc_type, exc_value, tb): modules.sd_models.reload_model_weights(self.model) - modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae)) + + opts.data["sd_vae"] = self.vae + modules.sd_vae.reload_vae_weights(self.model) hypernetwork.load_hypernetwork(self.hypernetwork) hypernetwork.apply_strength() -- cgit v1.2.3 From 9ef41df6f9043d58fbbeea1f06be8e5c8622248b Mon Sep 17 00:00:00 2001 From: Josh R Date: Sat, 14 Jan 2023 15:26:45 -0800 Subject: add inpaint masking controls to orderable section that the settings can order --- modules/shared.py | 1 + modules/ui.py | 58 +++++++++++++++++++++++++++---------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 51df056c..7ce8003f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -116,6 +116,7 @@ restricted_opts = { } ui_reorder_categories = [ + "masking", "sampler", "dimensions", "cfg", diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..174930ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -867,35 +867,6 @@ def create_ui(): outputs=[], ) - with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - - with FormRow(): - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - - with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -937,6 +908,35 @@ def create_ui(): with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() + elif category == "masking": + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) -- cgit v1.2.3 From d8b90ac121cbf0c18b1dc9d56a5e1d14ca51e74e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 18:50:56 +0300 Subject: big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other --- javascript/progressbar.js | 249 ++++++++++++++++--------- javascript/textualInversion.js | 13 +- javascript/ui.js | 33 +++- modules/call_queue.py | 19 +- modules/hypernetworks/hypernetwork.py | 6 +- modules/img2img.py | 2 +- modules/progress.py | 96 ++++++++++ modules/sd_samplers.py | 2 +- modules/shared.py | 16 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- modules/txt2img.py | 2 +- modules/ui.py | 41 ++-- modules/ui_progress.py | 101 ---------- style.css | 74 +++++--- webui.py | 3 + 16 files changed, 390 insertions(+), 275 deletions(-) create mode 100644 modules/progress.py delete mode 100644 modules/ui_progress.py (limited to 'modules/shared.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index d6323ed9..b7524ef7 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,82 +1,25 @@ // code related to showing and updating progressbar shown as the image is being made -global_progressbars = {} -galleries = {} -galleryObservers = {} - -// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running -timeoutIds = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ - // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id - // every time you use gr.HTML(elem_id='xxx'), so we handle this here - var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar) - var progressbarParent - if(progressbar){ - progressbarParent = gradioApp().querySelector("#"+id_progressbar) - } else{ - progressbar = gradioApp().getElementById(id_progressbar) - progressbarParent = null - } - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - - if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ - if(progressbar.innerText){ - let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; - if(document.title != newtitle){ - document.title = newtitle; - } - }else{ - let newtitle = 'Stable Diffusion' - if(document.title != newtitle){ - document.title = newtitle; - } - } - } - - if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ - global_progressbars[id_progressbar] = progressbar - - var mutationObserver = new MutationObserver(function(m){ - if(timeoutIds[id_part]) return; - - preview = gradioApp().getElementById(id_preview) - gallery = gradioApp().getElementById(id_gallery) +galleries = {} +storedGallerySelections = {} +galleryObservers = {} - if(preview != null && gallery != null){ - preview.style.width = gallery.clientWidth + "px" - preview.style.height = gallery.clientHeight + "px" - if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px" +function rememberGallerySelection(id_gallery){ + storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery) +} - //only watch gallery if there is a generation process going on - check_gallery(id_gallery); +function getGallerySelectedIndex(id_gallery){ + let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') + let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - if(progressDiv){ - timeoutIds[id_part] = window.setTimeout(function() { - timeoutIds[id_part] = null - requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) - }, 500) - } else{ - if (skip) { - skip.style.display = "none" - } - interrupt.style.display = "none" + let currentlySelectedIndex = -1 + galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } }) - //disconnect observer once generation finished, so user can close selected image if they want - if (galleryObservers[id_gallery]) { - galleryObservers[id_gallery].disconnect(); - galleries[id_gallery] = null; - } - } - } - - }); - mutationObserver.observe( progressbar, { childList:true, subtree:true }) - } + return currentlySelectedIndex } +// this is a workaround for https://github.com/gradio-app/gradio/issues/2984 function check_gallery(id_gallery){ let gallery = gradioApp().getElementById(id_gallery) // if gallery has no change, no need to setting up observer again. @@ -85,10 +28,16 @@ function check_gallery(id_gallery){ if(galleryObservers[id_gallery]){ galleryObservers[id_gallery].disconnect(); } - let prevSelectedIndex = selected_gallery_index(); + + storedGallerySelections[id_gallery] = -1 + galleryObservers[id_gallery] = new MutationObserver(function (){ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') + let currentlySelectedIndex = getGallerySelectedIndex(id_gallery) + prevSelectedIndex = storedGallerySelections[id_gallery] + storedGallerySelections[id_gallery] = -1 + if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { // automatically re-open previously selected index (if exists) activeElement = gradioApp().activeElement; @@ -120,30 +69,150 @@ function check_gallery(id_gallery){ } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_gallery('txt2img_gallery') + check_gallery('img2img_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ - btn = gradioApp().getElementById(id_part+"_check_progress"); - if(btn==null) return; - - btn.click(); - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - if(progressDiv && interrupt){ - if (skip) { - skip.style.display = "block" +function request(url, data, handler, errorHandler){ + var xhr = new XMLHttpRequest(); + var url = url; + xhr.open("POST", url, true); + xhr.setRequestHeader("Content-Type", "application/json"); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + var js = JSON.parse(xhr.responseText); + handler(js) + } else{ + errorHandler() + } } - interrupt.style.display = "block" + }; + var js = JSON.stringify(data); + xhr.send(js); +} + +function pad2(x){ + return x<10 ? '0'+x : x +} + +function formatTime(secs){ + if(secs > 3600){ + return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60) + } else if(secs > 60){ + return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60) + } else{ + return Math.floor(secs) + "s" } } -function requestProgress(id_part){ - btn = gradioApp().getElementById(id_part+"_check_progress_initial"); - if(btn==null) return; +function randomId(){ + return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" +} + +// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and +// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd. +// calls onProgress every time there is a progress update +function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){ + var dateStart = new Date() + var wasEverActive = false + var parentProgressbar = progressbarContainer.parentNode + var parentGallery = gallery.parentNode + + var divProgress = document.createElement('div') + divProgress.className='progressDiv' + var divInner = document.createElement('div') + divInner.className='progress' + + divProgress.appendChild(divInner) + parentProgressbar.insertBefore(divProgress, progressbarContainer) + + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + + var removeProgressBar = function(){ + parentProgressbar.removeChild(divProgress) + parentGallery.removeChild(livePreview) + atEnd() + } + + var fun = function(id_task, id_live_preview){ + request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ + console.log(res) + + if(res.completed){ + removeProgressBar() + return + } + + var rect = progressbarContainer.getBoundingClientRect() + + if(rect.width){ + divProgress.style.width = rect.width + "px"; + } + + progressText = "" + + divInner.style.width = ((res.progress || 0) * 100.0) + '%' + + if(res.progress > 0){ + progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' + } + + if(res.eta){ + progressText += " ETA: " + formatTime(res.eta) + } else if(res.textinfo){ + progressText += " " + res.textinfo + } + + divInner.textContent = progressText + + var elapsedFromStart = (new Date() - dateStart) / 1000 + + if(res.active) wasEverActive = true; + + if(! res.active && wasEverActive){ + removeProgressBar() + return + } + + if(elapsedFromStart > 5 && !res.queued && !res.active){ + removeProgressBar() + return + } + + + if(res.live_preview){ + var img = new Image(); + img.onload = function() { + var rect = gallery.getBoundingClientRect() + if(rect.width){ + livePreview.style.width = rect.width + "px" + livePreview.style.height = rect.height + "px" + } + + livePreview.innerHTML = '' + livePreview.appendChild(img) + if(livePreview.childElementCount > 2){ + livePreview.removeChild(livePreview.firstElementChild) + } + } + img.src = res.live_preview; + } + + + if(onProgress){ + onProgress(res) + } + + setTimeout(() => { + fun(id_task, res.id_live_preview); + }, 500) + }, function(){ + removeProgressBar() + }) + } - btn.click(); + fun(id_task, 0) } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8061be08..0354b860 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,8 +1,17 @@ + function start_training_textual_inversion(){ - requestProgress('ti') gradioApp().querySelector('#ti_error').innerHTML='' - return args_to_array(arguments) + var id = randomId() + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ + gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo + }) + + var res = args_to_array(arguments) + + res[0] = id + + return res } diff --git a/javascript/ui.js b/javascript/ui.js index f8279124..ecf97cb3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -126,18 +126,41 @@ function create_submit_args(args){ return res } +function showSubmitButtons(tabname, show){ + gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block" + gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block" +} + function submit(){ - requestProgress('txt2img') + rememberGallerySelection('txt2img_gallery') + showSubmitButtons('txt2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){ + showSubmitButtons('txt2img', true) + + }) - return create_submit_args(arguments) + var res = create_submit_args(arguments) + + res[0] = id + + return res } function submit_img2img(){ - requestProgress('img2img') + rememberGallerySelection('img2img_gallery') + showSubmitButtons('img2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ + showSubmitButtons('img2img', true) + }) - res = create_submit_args(arguments) + var res = create_submit_args(arguments) - res[0] = get_tab_index('mode_img2img') + res[0] = id + res[1] = get_tab_index('mode_img2img') return res } diff --git a/modules/call_queue.py b/modules/call_queue.py index 4cd49533..92097c15 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,7 +4,7 @@ import threading import traceback import time -from modules import shared +from modules import shared, progress queue_lock = threading.Lock() @@ -22,12 +22,23 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - shared.state.begin() + # if the first argument is a string that says "task(...)", it is treated as a job id + if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + id_task = args[0] + progress.add_task_to_queue(id_task) + else: + id_task = None with queue_lock: - res = func(*args, **kwargs) + shared.state.begin() + progress.start_task(id_task) + + try: + res = func(*args, **kwargs) + finally: + progress.finish_task(id_task) - shared.state.end() + shared.state.end() return res diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3aebefa8..ae6af516 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' @@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, torch.cuda.set_rng_state_all(cuda_rng_state) hypernetwork.train() if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..f4a03c57 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img diff --git a/modules/progress.py b/modules/progress.py new file mode 100644 index 00000000..3327b883 --- /dev/null +++ b/modules/progress.py @@ -0,0 +1,96 @@ +import base64 +import io +import time + +import gradio as gr +from pydantic import BaseModel, Field + +from modules.shared import opts + +import modules.shared as shared + + +current_task = None +pending_tasks = {} +finished_tasks = [] + + +def start_task(id_task): + global current_task + + current_task = id_task + pending_tasks.pop(id_task, None) + + +def finish_task(id_task): + global current_task + + if current_task == id_task: + current_task = None + + finished_tasks.append(id_task) + if len(finished_tasks) > 16: + finished_tasks.pop(0) + + +def add_task_to_queue(id_job): + pending_tasks[id_job] = time.time() + + +class ProgressRequest(BaseModel): + id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") + id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") + + +class ProgressResponse(BaseModel): + active: bool = Field(title="Whether the task is being worked on right now") + queued: bool = Field(title="Whether the task is in queue") + completed: bool = Field(title="Whether the task has already finished") + progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") + eta: float = Field(default=None, title="ETA in secs") + live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") + id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") + + +def setup_progress_api(app): + return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + + +def progressapi(req: ProgressRequest): + active = req.id_task == current_task + queued = req.id_task in pending_tasks + completed = req.id_task in finished_tasks + + if not active: + return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...") + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + progress = min(progress, 1) + + elapsed_since_start = time.time() - shared.state.time_start + predicted_duration = elapsed_since_start / progress if progress > 0 else None + eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None + + id_live_preview = req.id_live_preview + shared.state.set_current_image() + if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview: + image = shared.state.current_image + if image is not None: + buffered = io.BytesIO() + image.save(buffered, format="png") + live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") + id_live_preview = shared.state.id_live_preview + else: + live_preview = None + else: + live_preview = None + + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) + diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7616fded..76e0e0d5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -140,7 +140,7 @@ def store_latent(decoded): if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: - shared.state.current_image = sample_to_image(decoded) + shared.state.assign_current_image(sample_to_image(decoded)) class InterruptedException(BaseException): diff --git a/modules/shared.py b/modules/shared.py index 51df056c..de99aca9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,6 +152,7 @@ def reload_hypernetworks(): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + class State: skipped = False interrupted = False @@ -165,6 +166,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + id_live_preview = 0 textinfo = None time_start = None need_restart = False @@ -207,6 +209,7 @@ class State: self.current_latent = None self.current_image = None self.current_image_sampling_step = 0 + self.id_live_preview = 0 self.skipped = False self.interrupted = False self.textinfo = None @@ -220,8 +223,8 @@ class State: devices.torch_gc() - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" if not parallel_processing_allowed: return @@ -234,12 +237,16 @@ class State: import modules.sd_samplers if opts.show_progress_grid: - self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) else: - self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) self.current_image_sampling_step = self.sampling_step + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 + state = State() state.server_start = time.time() @@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('ui', "User interface"), { - "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), @@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3c1042ad..64abff4d 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 63935878..7e4a6d24 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) @@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' @@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ shared.sd_model.first_stage_model.to(devices.cpu) if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..ca5d4550 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..ff33236b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -356,7 +356,7 @@ def create_toprow(is_img2img): button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_generate_box"): skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') @@ -384,9 +384,7 @@ def create_toprow(is_img2img): def setup_progressbar(*args, **kwargs): - import modules.ui_progress - - modules.ui_progress.setup_progressbar(*args, **kwargs) + pass def apply_setting(key, value): @@ -479,8 +477,8 @@ Requested path was: {f} else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel'): - with gr.Group(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) generation_info = None @@ -595,15 +593,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -682,6 +671,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ + dummy_component, txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_style, @@ -782,16 +772,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): @@ -958,6 +939,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ + dummy_component, dummy_component, img2img_prompt, img2img_negative_prompt, @@ -1335,15 +1317,11 @@ def create_ui(): script_callbacks.ui_train_tabs_callback(params) - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") + with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) ti_progress = gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) create_embedding.click( fn=modules.textual_inversion.ui.create_embedding, @@ -1384,6 +1362,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, process_src, process_dst, process_width, @@ -1411,6 +1390,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_embedding_name, embedding_learn_rate, batch_size, @@ -1443,6 +1423,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_hypernetwork_name, hypernetwork_learn_rate, batch_size, diff --git a/modules/ui_progress.py b/modules/ui_progress.py deleted file mode 100644 index 7cd312e4..00000000 --- a/modules/ui_progress.py +++ /dev/null @@ -1,101 +0,0 @@ -import time - -import gradio as gr - -from modules.shared import opts - -import modules.shared as shared - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr.update(visible=False) - preview_visibility = gr.update(visible=False) - - if opts.live_previews_enable: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr.update(visible=True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr.update(visible=False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) diff --git a/style.css b/style.css index 2d484e06..786b71d1 100644 --- a/style.css +++ b/style.css @@ -305,26 +305,42 @@ input[type="range"]{ } .progressDiv{ - width: 100%; - height: 20px; - background: #b4c0cc; - border-radius: 8px; + position: absolute; + height: 20px; + top: -20px; + background: #b4c0cc; + border-radius: 8px !important; } .dark .progressDiv{ - background: #424c5b; + background: #424c5b; } .progressDiv .progress{ - width: 0%; - height: 20px; - background: #0060df; - color: white; - font-weight: bold; - line-height: 20px; - padding: 0 8px 0 0; - text-align: right; - border-radius: 8px; + width: 0%; + height: 20px; + background: #0060df; + color: white; + font-weight: bold; + line-height: 20px; + padding: 0 8px 0 0; + text-align: right; + border-radius: 8px; + overflow: visible; + white-space: nowrap; +} + +.livePreview{ + position: absolute; + z-index: 300; + background-color: white; + margin: -4px; +} + +.livePreview img{ + object-fit: contain; + width: 100%; + height: 100%; } #lightboxModal{ @@ -450,23 +466,25 @@ input[type="range"]{ display:none } -#txt2img_interrupt, #img2img_interrupt{ - position: absolute; - width: 50%; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; +#txt2img_generate_box, #img2img_generate_box{ + position: relative; +} + +#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + height: 100%; + background: #b4c0cc; + display: none; } +#txt2img_interrupt, #img2img_interrupt{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} #txt2img_skip, #img2img_skip{ - position: absolute; - width: 50%; - right: 0px; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; + left: 0; + border-radius: 0.5rem 0 0 0.5rem; } .red { diff --git a/webui.py b/webui.py index 1fff80da..4624fe18 100644 --- a/webui.py +++ b/webui.py @@ -34,6 +34,7 @@ import modules.sd_vae import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion +import modules.progress import modules.ui from modules import modelloader @@ -181,6 +182,8 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) + modules.progress.setup_progress_api(app) + if launch_api: create_api(app) -- cgit v1.2.3 From 388708f7b13dfbc890135cad678bfbcebd7baf37 Mon Sep 17 00:00:00 2001 From: pangbo13 <373108669@qq.com> Date: Mon, 16 Jan 2023 00:56:24 +0800 Subject: fix when show_progress_every_n_steps == -1 --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index de99aca9..f857ccde 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,7 +228,7 @@ class State: if not parallel_processing_allowed: return - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable: + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1: self.do_set_current_image() def do_set_current_image(self): -- cgit v1.2.3 From a534bdfc801e0c83e378dfaa2d04cf865d7109f9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 20:27:39 +0300 Subject: add setting for progressbar update period --- javascript/progressbar.js | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 8f22c018..59173c83 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -208,7 +208,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre setTimeout(() => { fun(id_task, res.id_live_preview); - }, 500) + }, opts.live_preview_refresh_period || 500) }, function(){ removeProgressBar() }) diff --git a/modules/shared.py b/modules/shared.py index de99aca9..3483db1c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -455,6 +455,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), + "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From 3db22e6ee45193559a2c3ba44ab672b067245f99 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:32:38 +0300 Subject: rename masking to inpaint in UI make inpaint go to the right place for users who don't have it in config string --- modules/shared.py | 2 +- modules/ui.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 3bdc375b..f06ae610 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -116,7 +116,7 @@ restricted_opts = { } ui_reorder_categories = [ - "masking", + "inpaint", "sampler", "dimensions", "cfg", diff --git a/modules/ui.py b/modules/ui.py index b3d4af3e..20b66165 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -570,9 +570,9 @@ def create_sampler_and_steps_selection(choices, tabname): def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))} - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): yield category @@ -889,7 +889,7 @@ def create_ui(): with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() - elif category == "masking": + elif category == "inpaint": with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") -- cgit v1.2.3 From 3f887f7f61d69fa699a272166b79fdb787e9ce1d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 00:44:46 +0300 Subject: support old configs that say "auto" for ssd_vae change sd_vae_as_default to True by default as it's a more sensible setting --- modules/sd_vae.py | 6 ++++-- modules/shared.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index add5cecf..e9c6bb40 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -94,8 +94,10 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"): + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): return vae_near_checkpoint, 'found near the checkpoint' if shared.opts.sd_vae == "None": @@ -105,7 +107,7 @@ def resolve_vae(checkpoint_file): if vae_from_options is not None: return vae_from_options, 'specified in settings' - if shared.opts.sd_vae != "Automatic": + if is_automatic: print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") return None, None diff --git a/modules/shared.py b/modules/shared.py index f06ae610..c5fc250e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -394,7 +394,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), - "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3 From 064983c0adb00cd9e88d2f06f66c9a1d5bc116c3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 12:56:30 +0300 Subject: return an option to hide progressbar --- javascript/progressbar.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) (limited to 'modules/shared.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 5072c13f..da6709bc 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -121,6 +121,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var divProgress = document.createElement('div') divProgress.className='progressDiv' + divProgress.style.display = opts.show_progressbar ? "" : "none" var divInner = document.createElement('div') divInner.className='progress' diff --git a/modules/shared.py b/modules/shared.py index c5fc250e..483c4c62 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -451,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('ui', "Live previews"), { + "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), -- cgit v1.2.3 From c361b89026442f3412162657f330d500b803e052 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 17 Jan 2023 11:04:56 +0300 Subject: disable the new NaN check for the CI --- launch.py | 2 ++ modules/devices.py | 3 +++ modules/shared.py | 1 + 3 files changed, 6 insertions(+) (limited to 'modules/shared.py') diff --git a/launch.py b/launch.py index 715427fd..5afb2956 100644 --- a/launch.py +++ b/launch.py @@ -286,6 +286,8 @@ def tests(test_dir): sys.argv.append("./test/test_files/empty.pt") if "--skip-torch-cuda-test" not in sys.argv: sys.argv.append("--skip-torch-cuda-test") + if "--disable-nan-check" not in sys.argv: + sys.argv.append("--disable-nan-check") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") diff --git a/modules/devices.py b/modules/devices.py index 6f034948..206184fb 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -113,6 +113,9 @@ class NansException(Exception): def test_for_nans(x, where): from modules import shared + if shared.cmd_opts.disable_nan_check: + return + if not torch.all(torch.isnan(x)).item(): return diff --git a/modules/shared.py b/modules/shared.py index 483c4c62..a708f23c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,6 +64,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) -- cgit v1.2.3 From 6e08da2c315c346225aa834017f4e32cfc0de200 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Tue, 17 Jan 2023 23:50:41 +0900 Subject: Add `--vae-dir` argument --- modules/sd_vae.py | 7 +++++++ modules/shared.py | 1 + 2 files changed, 8 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index b2af2ce7..da1bf15c 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -72,6 +72,13 @@ def refresh_vae_list(): os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), ] + if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir): + paths += [ + os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'), + ] + candidates = [] for path in paths: candidates += glob.iglob(path, recursive=True) diff --git a/modules/shared.py b/modules/shared.py index a708f23c..a1345ad3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -26,6 +26,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") +parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") -- cgit v1.2.3 From d906f87043d809e6d4d8de3c9926e184169b330f Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Wed, 18 Jan 2023 07:52:10 +0900 Subject: fix typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a1345ad3..a42279ec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -26,7 +26,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") -parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files") +parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") -- cgit v1.2.3 From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 23:04:24 +0300 Subject: add option to show/hide warnings removed hiding warnings from LDSR fixed/reworked few places that produced warnings --- extensions-builtin/LDSR/ldsr_model_arch.py | 3 -- javascript/localization.js | 2 +- modules/hypernetworks/hypernetwork.py | 7 ++++- modules/sd_hijack.py | 8 ------ modules/sd_hijack_checkpoint.py | 38 +++++++++++++++++++++++++- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 6 +++- modules/ui.py | 31 ++++++++++++--------- scripts/prompts_from_file.py | 2 +- style.css | 5 ++-- 10 files changed, 71 insertions(+), 32 deletions(-) (limited to 'modules/shared.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 0ad49f4e..bc11cc6e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -1,7 +1,6 @@ import os import gc import time -import warnings import numpy as np import torch @@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap from modules import shared, sd_hijack -warnings.filterwarnings("ignore", category=UserWarning) - cached_ldsr_model: torch.nn.Module = None diff --git a/javascript/localization.js b/javascript/localization.js index bf9e1506..1a5a1dbb 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,7 +11,7 @@ ignore_ids_for_localization={ train_embedding: 'OPTION', train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', - img2img_styles 'OPTION', + img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c963fc40..74e78582 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() hypernetwork.eval() #report_statistics(loss_dict) + sd_hijack_checkpoint.remove() + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b0d95af..870eba88 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,12 +69,6 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def fix_checkpoint(): - ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward - - class StableDiffusionModelHijack: fixes = None comments = [] @@ -106,8 +100,6 @@ class StableDiffusionModelHijack: self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model - - fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 5712972f..2604d969 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,10 +1,46 @@ from torch.utils.checkpoint import checkpoint +import ldm.modules.attention +import ldm.modules.diffusionmodules.openaimodel + + def BasicTransformerBlock_forward(self, x, context=None): return checkpoint(self._forward, x, context) + def AttentionBlock_forward(self, x): return checkpoint(self._forward, x) + def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) \ No newline at end of file + return checkpoint(self._forward, x, emb) + + +stored = [] + + +def add(): + if len(stored) != 0: + return + + stored.extend([ + ldm.modules.attention.BasicTransformerBlock.forward, + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward + ]) + + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward + + +def remove(): + if len(stored) == 0: + return + + ldm.modules.attention.BasicTransformerBlock.forward = stored[0] + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] + + stored.clear() + diff --git a/modules/shared.py b/modules/shared.py index a708f23c..ddb97f99 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" })) options_templates.update(options_section(('system', "System"), { + "show_warnings": OptionInfo(False, "Show warnings in console."), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7e4a6d24..5a7be422 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -15,7 +15,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() shared.sd_model.first_stage_model.to(devices.device) shared.parallel_processing_allowed = old_parallel_processing_allowed + sd_hijack_checkpoint.remove() return embedding, filename + def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None diff --git a/modules/ui.py b/modules/ui.py index 6d70a795..25818fb0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import tempfile import time import traceback from functools import partial, reduce +import warnings import gradio as gr import gradio.routes @@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -417,17 +420,16 @@ def apply_setting(key, value): return value -def update_generation_info(args): - generation_info, html_info, img_index = args +def update_generation_info(generation_info, html_info, img_index): try: generation_info = json.loads(generation_info) if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() except Exception: pass # if the json parse or anything else fails, just return the old html_info - return html_info + return html_info, gr.update() def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): @@ -508,10 +510,9 @@ Requested path was: {f} generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False + _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], ) save.click( @@ -526,7 +527,8 @@ Requested path was: {f} outputs=[ download_files, html_log, - ] + ], + show_progress=False, ) save_zip.click( @@ -588,7 +590,7 @@ def create_ui(): txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): @@ -768,7 +770,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): @@ -1768,7 +1770,10 @@ def create_ui(): if saved_value is None: ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + pass + + # this warning is generally not useful; + # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') else: setattr(obj, field, saved_value) if init_field is not None: diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index f3e711d7..76dc5778 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -116,7 +116,7 @@ class Script(scripts.Script): checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) - file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) + file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/style.css b/style.css index 61279a19..0845519a 100644 --- a/style.css +++ b/style.css @@ -299,9 +299,8 @@ input[type="range"]{ } /* more gradio's garbage cleanup */ -.min-h-\[4rem\] { - min-height: unset !important; -} +.min-h-\[4rem\] { min-height: unset !important; } +.min-h-\[6rem\] { min-height: unset !important; } .progressDiv{ position: absolute; -- cgit v1.2.3 From 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 10:39:51 +0300 Subject: allow baking in VAE in checkpoint merger tab do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab --- javascript/hints.js | 1 + javascript/ui.js | 2 - modules/extras.py | 112 +++++++++++++++++++++++++++++++--------------------- modules/sd_vae.py | 9 ++++- modules/shared.py | 3 +- modules/ui.py | 17 ++++++-- style.css | 15 ++++--- 7 files changed, 101 insertions(+), 58 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index fa5e5ae8..e746e20d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "No interpolation": "Result = A", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", diff --git a/javascript/ui.js b/javascript/ui.js index 428375d4..37788a3e 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -176,8 +176,6 @@ function modelmerger(){ var id = randomId() requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) - gradioApp().getElementById('modelmerger_result').innerHTML = '' - var res = create_submit_args(arguments) res[0] = id return res diff --git a/modules/extras.py b/modules/extras.py index 034f28e4..fe701a0e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -251,7 +251,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - return sd_models.find_checkpoint_config(x) if x else None + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None if config_source == 0: cfg = config(a) or config(b) or config(c) @@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' - shared.state.job_count = 1 def fail(message): shared.state.textinfo = message @@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + def filename_weighed_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_differnece(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighed_sum, None, weighted_sum), + "Add difference": (filename_add_differnece, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + if not primary_model_name: return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] - if not secondary_model_name: + if theta_func2 and not secondary_model_name: return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - theta_funcs = { - "Weighted sum": (None, weighted_sum), - "Add difference": (get_difference, add_difference), - } - theta_func1, theta_func2 = theta_funcs[interp_method] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None if theta_func1 and not tertiary_model_name: return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - shared.state.textinfo = f"Loading {secondary_model_info.filename}..." - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None if theta_func1: - shared.state.job_count += 1 - + shared.state.textinfo = f"Loading C" print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.textinfo = 'Merging B and C' shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): + if key in chckpoint_dict_skip_on_merge: + continue + if 'model' in key: if key in theta_2: t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) @@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') print("Merging...") - - chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - + shared.state.textinfo = 'Merging A and B' shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): - if 'model' in key and key in theta_1: + if theta_1 and 'model' in key and key in theta_1: if key in chckpoint_dict_skip_on_merge: continue @@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ a = theta_0[key] b = theta_1[key] - shared.state.textinfo = f'Merging layer {key}' # this enables merging an inpainting model (A) with another one (B); # where normal model would have 4 channels, for latenst space, inpainting model would # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 @@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.sampling_step += 1 - # I believe this part should be discarded, but I'll leave it for now until I am sure - for key in theta_1.keys(): - if 'model' in key and key not in theta_0: + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - if key in chckpoint_dict_skip_on_merge: - continue + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] - theta_0[key] = theta_1[key] - if save_as_half: - theta_0[key] = theta_0[key].half() - del theta_1 + del vae_dict ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = \ - primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \ - secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \ - interp_method.replace(" ", "_") + \ - '-merged.' + \ - ("inpainting." if result_is_inpainting_model else "") + \ - checkpoint_format - - filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format) + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) shared.state.nextjob() - shared.state.textinfo = f"Saving to {output_modelname}..." + shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") _, extension = os.path.splitext(output_modelname) @@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - print("Checkpoint saved.") - shared.state.textinfo = "Checkpoint saved to " + output_modelname + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" shared.state.end() return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/sd_vae.py b/modules/sd_vae.py index da1bf15c..4ce238b8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file): return None, None +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + def load_vae(model, vae_file=None, vae_source="from unknown source"): global vae_dict, loaded_vae_file # save_settings = False @@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) - vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) _load_vae_dict(model, vae_dict_1) if cache_enabled: diff --git a/modules/shared.py b/modules/shared.py index 77e5e91c..29b28bff 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path demo = None +sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") diff --git a/modules/ui.py b/modules/ui.py index aeee7853..4e381a49 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -1185,7 +1185,7 @@ def create_ui(): with gr.Column(variant='compact'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - with FormRow(): + with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1197,13 +1197,20 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + with FormRow(): + with gr.Column(): + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with FormRow(): + bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1757,6 +1764,7 @@ def create_ui(): return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results + modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) modelmerger_merge.click( fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), _js='modelmerger', @@ -1771,6 +1779,7 @@ def create_ui(): custom_name, checkpoint_format, config_source, + bake_in_vae, ], outputs=[ primary_model_name, diff --git a/style.css b/style.css index 32ba4753..c10e32a1 100644 --- a/style.css +++ b/style.css @@ -641,6 +641,16 @@ canvas[key="mask"] { margin: 0.6em 0em 0.55em 0; } +#modelmerger_results_container{ + margin-top: 1em; + overflow: visible; +} + +#modelmerger_models{ + gap: 0; +} + + #quicksettings .gr-button-tool{ margin: 0; } @@ -737,11 +747,6 @@ footer { line-height: 2.4em; } -#modelmerger_results_container{ - margin-top: 1em; - overflow: visible; -} - /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From b271e22f7ac1b2cabca8985b1e4437ab685a2c21 Mon Sep 17 00:00:00 2001 From: vt-idiot <81622808+vt-idiot@users.noreply.github.com> Date: Thu, 19 Jan 2023 06:12:19 -0500 Subject: Update shared.py `Witdth/Height` was driving me insane. -> `Width/Height` --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 29b28bff..2f366454 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), - "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), -- cgit v1.2.3 From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 08:36:07 +0300 Subject: extra networks UI rework of hypernets: rather than via settings, hypernets are added directly to prompt as --- html/card-no-preview.png | Bin 0 -> 84440 bytes html/extra-networks-card.html | 11 ++ html/extra-networks-no-cards.html | 8 ++ javascript/extraNetworks.js | 60 ++++++++ javascript/hints.js | 2 + javascript/ui.js | 9 +- modules/api/api.py | 7 +- modules/extra_networks.py | 147 +++++++++++++++++++ modules/extra_networks_hypernet.py | 21 +++ modules/generation_parameters_copypaste.py | 12 +- modules/hypernetworks/hypernetwork.py | 107 +++++++++----- modules/hypernetworks/ui.py | 5 +- modules/processing.py | 24 ++-- modules/sd_hijack_optimizations.py | 10 +- modules/shared.py | 21 ++- modules/textual_inversion/textual_inversion.py | 2 + modules/ui.py | 50 ++++--- modules/ui_components.py | 10 ++ modules/ui_extra_networks.py | 149 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 34 +++++ modules/ui_extra_networks_textual_inversion.py | 32 +++++ script.js | 13 +- scripts/xy_grid.py | 29 ---- style.css | 190 +++++++++++++------------ webui.py | 26 +++- 25 files changed, 765 insertions(+), 214 deletions(-) create mode 100644 html/card-no-preview.png create mode 100644 html/extra-networks-card.html create mode 100644 html/extra-networks-no-cards.html create mode 100644 javascript/extraNetworks.js create mode 100644 modules/extra_networks.py create mode 100644 modules/extra_networks_hypernet.py create mode 100644 modules/ui_extra_networks.py create mode 100644 modules/ui_extra_networks_hypernets.py create mode 100644 modules/ui_extra_networks_textual_inversion.py (limited to 'modules/shared.py') diff --git a/html/card-no-preview.png b/html/card-no-preview.png new file mode 100644 index 00000000..e2beb269 Binary files /dev/null and b/html/card-no-preview.png differ diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html new file mode 100644 index 00000000..7314b063 --- /dev/null +++ b/html/extra-networks-card.html @@ -0,0 +1,11 @@ +
+
+
+ +
+ {name} +
+
+ diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html new file mode 100644 index 00000000..389358d6 --- /dev/null +++ b/html/extra-networks-no-cards.html @@ -0,0 +1,8 @@ +
+

Nothing here. Add some content to the following directories:

+ +
    +{dirs} +
+
+ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 00000000..71e522d1 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,60 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) +} + +var activePromptTextarea = null; +var activePositivePromptTextarea = null; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(id, isNegative){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (activePromptTextarea == null){ + activePromptTextarea = textarea + } + if (activePositivePromptTextarea == null && ! isNegative){ + activePositivePromptTextarea = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea = textarea; + if(! isNegative) activePositivePromptTextarea = textarea; + }); + } + + registerPrompt('txt2img_prompt') + registerPrompt('txt2img_neg_prompt', true) + registerPrompt('img2img_prompt') + registerPrompt('img2img_neg_prompt', true) +} + +onUiLoaded(setupExtraNetworks) + +function cardClicked(textToAdd, allowNegativePrompt){ + textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea + + textarea.value = textarea.value + " " + textToAdd + updateInput(textarea) + + return false +} + +function saveCardPreview(event, tabname, filename){ + textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} diff --git a/javascript/hints.js b/javascript/hints.js index e746e20d..f4079f96 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -21,6 +21,8 @@ titles = { "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 3ba90ca8..a7e75439 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } - - opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -239,11 +237,14 @@ onUiUpdate(function(){ return } + prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", () => update_token_counter(id_button)); + textarea.addEventListener("input", function(){ + update_token_counter(id_button); + }); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -261,10 +262,8 @@ onUiUpdate(function(){ }) } } - }) - onOptionsChanged(function(){ elem = gradioApp().getElementById('sd_checkpoint_hash') sd_checkpoint_hash = opts.sd_checkpoint_hash || "" diff --git a/modules/api/api.py b/modules/api/api.py index 9814bbc2..2c371e6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -480,7 +480,7 @@ class Api: def train_hypernetwork(self, args: dict): try: shared.state.begin() - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -491,16 +491,15 @@ class Api: except Exception as e: error = e finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {error}".format(error = error)) + return TrainResponse(info="train embedding error: {error}".format(error=error)) def get_memory(self): try: diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000..1978673d --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re +from collections import defaultdict + +from modules import errors + +extra_network_registry = {} + + +def initialize(): + extra_network_registry.clear() + + +def register_extra_network(extra_network): + extra_network_registry[extra_network.name] = extra_network + + +class ExtraNetworkParams: + def __init__(self, items=None): + self.items = items or [] + + +class ExtraNetwork: + def __init__(self, name): + self.name = name + + def activate(self, p, params_list): + """ + Called by processing on every run. Whatever the extra network is meant to do should be activated here. + Passes arguments related to this extra network in params_list. + User passes arguments by specifying this in his prompt: + + + + Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments + separated by colon. + + Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list - + in this case, all effects of this extra networks should be disabled. + + Can be called multiple times before deactivate() - each new call should override the previous call completely. + + For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is: + + > "1girl, " + + params_list will be: + + [ + ExtraNetworkParams(items=["agm", "1.1"]), + ExtraNetworkParams(items=["ray"]) + ] + + """ + raise NotImplementedError + + def deactivate(self, p): + """ + Called at the end of processing for housekeeping. No need to do anything here. + """ + + raise NotImplementedError + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + print(f"Skipping unknown extra network: {extra_network_name}") + continue + + try: + extra_network.activate(p, extra_network_args) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name}") + + +def deactivate(p, extra_network_data): + """call deactivate for extra networks in extra_network_data in specified order, then call + deactivate for all remaining registered networks""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating extra network {extra_network_name}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating unmentioned extra network {extra_network_name}") + + +re_extra_net = re.compile(r"<(\w+):([^>]+)>") + + +def parse_prompt(prompt): + res = defaultdict(list) + + def found(m): + name = m.group(1) + args = m.group(2) + + res[name].append(ExtraNetworkParams(items=args.split(":"))) + + return "" + + prompt = re.sub(re_extra_net, found, prompt) + + return prompt, res + + +def parse_prompts(prompts): + res = [] + extra_data = None + + for prompt in prompts: + updated_prompt, parsed_extra_data = parse_prompt(prompt) + + if extra_data is None: + extra_data = parsed_extra_data + + res.append(updated_prompt) + + return res, extra_data + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..6a0c4ba8 --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks +from modules.hypernetworks import hypernetwork + + +class ExtraNetworkHypernet(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('hypernet') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + hypernetwork.load_hypernetworks(names, multipliers) + + def deactivate(p, self): + pass diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui settings_map = { - 'sd_hypernetwork': 'Hypernet', - 'sd_hypernetwork_strength': 'Hypernet strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'inpainting_mask_weight': 'Conditional mask weight', 'sd_model_checkpoint': 'Model hash', @@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res: res["Clip skip"] = "1" - if "Hypernet strength" not in res: - res["Hypernet strength"] = "1" - - if "Hypernet" in res: - hypernet_name = res["Hypernet"] - hypernet_hash = res.get("Hypernet hash", None) - res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res["Prompt"] += f"""""" if "Hires resize-1" not in res: res["Hires resize-1"] = 0 diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74e78582..80a47c79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} class HypernetworkModule(torch.nn.Module): - multiplier = 1.0 activation_dict = { "linear": torch.nn.Identity, "relu": torch.nn.ReLU, @@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() + self.multiplier = 1.0 + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" @@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) + return x + self.linear(x) * (self.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure -def apply_strength(value=None): - HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength - #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): if layer_structure is None: @@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters(): param.requires_grad = mode + def to(self, device): + for k, layers in self.layers.items(): + for layer in layers: + layer.to(device) + + return self + + def set_multiplier(self, multiplier): + for k, layers in self.layers.items(): + for layer in layers: + layer.multiplier = multiplier + + return self + def eval(self): for k, layers in self.layers.items(): for layer in layers: @@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None if self.optimizer_state_dict: self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print("Loaded existing optimizer from checkpoint") - print(f"Optimizer name is {self.optimizer_name}") + if shared.opts.print_hypernet_extra: + print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: self.optimizer_name = "AdamW" - print("No saved optimizer exists in checkpoint") + if shared.opts.print_hypernet_extra: + print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: @@ -306,23 +320,43 @@ def list_hypernetworks(path): return res -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - # Prevent any file named "None.pt" from being loaded. - if path is not None and filename != "None": - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) +def load_hypernetwork(name): + path = shared.hypernetworks.get(name, None) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print("Unloading hypernetwork") + if path is None: + return None + + hypernetwork = Hypernetwork() + + try: + hypernetwork.load(path) + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + return hypernetwork + + +def load_hypernetworks(names, multipliers=None): + already_loaded = {} + + for hypernetwork in shared.loaded_hypernetworks: + if hypernetwork.name in names: + already_loaded[hypernetwork.name] = hypernetwork - shared.loaded_hypernetwork = None + shared.loaded_hypernetworks.clear() + + for i, name in enumerate(names): + hypernetwork = already_loaded.get(name, None) + if hypernetwork is None: + hypernetwork = load_hypernetwork(name) + + if hypernetwork is None: + continue + + hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0) + shared.loaded_hypernetworks.append(hypernetwork) def find_closest_hypernetwork_name(search: str): @@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0] -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) +def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) if hypernetwork_layers is None: - return context, context + return context_k, context_v if layer is not None: layer.hyper_k = hypernetwork_layers[0] layer.hyper_v = hypernetwork_layers[1] - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) + context_k = hypernetwork_layers[0](context_k) + context_v = hypernetwork_layers[1](context_v) + return context_k, context_v + + +def apply_hypernetworks(hypernetworks, context, layer=None): + context_k = context + context_v = context + for hypernetwork in hypernetworks: + context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer) + return context_k, context_v @@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) + hypernetwork = Hypernetwork() + hypernetwork.load(path) + shared.loaded_hypernetworks = [hypernetwork] shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." @@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: images_dir = None - hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() initial_step = hypernetwork.step or 0 diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) + def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) @@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' @@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..b5deeacf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), - "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': - shared.reload_hypernetworks() # make onchange call for changing hypernet if k == 'sd_model_checkpoint': - sd_models.reload_model_weights() # make onchange call for changing SD model + sd_models.reload_model_weights() if k == 'sd_vae': - sd_vae.reload_vae_weights() # make onchange call for changing VAE + sd_vae.reload_vae_weights() res = process_images_inner(p) @@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() + + if k == 'sd_vae': + sd_vae.reload_vae_weights() return res @@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] + p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + extra_networks.activate(p, extra_network_data) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) @@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + extra_networks.deactivate(p, extra_network_data) devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) del context, context_k, context_v, x @@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) @@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) * self.scale v = self.to_v(context_v) del context, context_k, context_v, x @@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x @@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) diff --git a/modules/shared.py b/modules/shared.py index 2f366454..c0e11f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,6 +23,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file + parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} -loaded_hypernetwork = None +loaded_hypernetworks = [] def reload_hypernetworks(): @@ -153,8 +154,6 @@ def reload_hypernetworks(): global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - class State: @@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -661,3 +658,17 @@ mem_mon.start() def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None + self.filename = None def save(self, filename): embedding_data = { @@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] + embedding.filename = path if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/modules/ui.py b/modules/ui.py index 06c11848..d23b2b8e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): @@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) @@ -354,10 +357,10 @@ def create_toprow(is_img2img): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") @@ -395,11 +398,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id="style_create") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -616,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -794,14 +804,20 @@ def create_ui(): token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] @@ -1064,6 +1080,8 @@ def create_ui(): token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1666,10 +1684,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") @@ -1756,11 +1772,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button" +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + class FormRow(gr.Row, gr.components.FormComponent): """Same as gr.Row but fits inside gradio forms""" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..253e90f7 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,149 @@ +import os.path + +from modules import shared +import gradio as gr +import json + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + + extra_pages.append(page) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def create_html(self, tabname): + items_html = '' + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + res = "
    " + items_html + "
    " + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + args = { + "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "prompt": json.dumps(item["prompt"]), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + } + + return self.card_page.format(**args) + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = extra_pages.copy() + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) + button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..312dbaf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,34 @@ +import os + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..e4a6e3bf --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,32 @@ +import os + +from modules import ui_extra_networks, sd_hijack + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + preview_file = path + ".preview.png" + + preview = None + if os.path.isfile(preview_file): + preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": preview, + "prompt": embedding.name, + "local_preview": path + ".preview.png", + } + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/script.js b/script.js index 3345e32b..97e0bfcf 100644 --- a/script.js +++ b/script.js @@ -13,6 +13,7 @@ function get_uiCurrentTabContent() { } uiUpdateCallbacks = [] +uiLoadedCallbacks = [] uiTabChangeCallbacks = [] optionsChangedCallbacks = [] let uiCurrentTab = null @@ -20,6 +21,9 @@ let uiCurrentTab = null function onUiUpdate(callback){ uiUpdateCallbacks.push(callback) } +function onUiLoaded(callback){ + uiLoadedCallbacks.push(callback) +} function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } @@ -38,8 +42,15 @@ function executeCallbacks(queue, m) { queue.forEach(function(x){runCallback(x, m)}) } +var executedOnLoaded = false; + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ + if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { @@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() { /** * Add a ctrl+enter as a shortcut to start a generation */ - document.addEventListener('keydown', function(e) { +document.addEventListener('keydown', function(e) { var handled = false; if (e.key !== undefined) { if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6629f5d5..b1badec9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,6 @@ import modules.scripts as scripts import gradio as gr from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_hypernetwork(p, x, xs): - if x.lower() in ["", "none"]: - name = None - else: - name = hypernetwork.find_closest_hypernetwork_name(x) - if not name: - raise RuntimeError(f"Unknown hypernetwork: {x}") - hypernetwork.load_hypernetwork(name) - - -def apply_hypernetwork_strength(p, x, xs): - hypernetwork.apply_strength(x) - - -def confirm_hypernetworks(p, xs): - for x in xs: - if x.lower() in ["", "none"]: - continue - if not hypernetwork.find_closest_hypernetwork_name(x): - raise RuntimeError(f"Unknown hypernetwork: {x}") - - def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -208,8 +185,6 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), @@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.hypernetwork = opts.sd_hypernetwork self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object): modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() - hypernetwork.load_hypernetwork(self.hypernetwork) - hypernetwork.apply_strength() - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers diff --git a/style.css b/style.css index 3a515ebd..5e8bc2ca 100644 --- a/style.css +++ b/style.css @@ -132,13 +132,6 @@ } #roll_col > button { - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; margin: 0.1em 0; } @@ -146,9 +139,10 @@ min-width: 0 !important; max-width: 8em !important; margin-right: 1em; + gap: 0; } #interrogate, #deepbooru{ - margin: 0em 0.25em 0.9em 0.25em; + margin: 0em 0.25em 0.5em 0.25em; min-width: 8em; max-width: 8em; } @@ -157,8 +151,17 @@ min-width: 8em !important; } +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.5em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + #txt2img_styles, #img2img_styles{ - margin-top: 1em; + padding: 0; } #txt2img_styles ul, #img2img_styles ul{ @@ -635,17 +638,21 @@ canvas[key="mask"] { background-color: rgb(31 41 55 / var(--tw-bg-opacity)); } -.gr-button-tool{ +.gr-button-tool, .gr-button-tool-top{ max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 1.6em 0.7em 0.55em 0; } -#tab_modelmerger .gr-button-tool{ +.gr-button-tool{ margin: 0.6em 0em 0.55em 0; } +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + #modelmerger_results_container{ margin-top: 1em; overflow: visible; @@ -763,81 +770,88 @@ footer { line-height: 2.4em; } -/* The following handles localization for right-to-left (RTL) languages like Arabic. -The rtl media type will only be activated by the logic in javascript/localization.js. -If you change anything above, you need to make sure it is RTL compliant by just running -your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. -Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ -@media rtl { - /* this part was added manually */ - :host { - direction: rtl; - } - select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { - direction: ltr; - } - #script_list > label > select, - #x_type > label > select, - #y_type > label > select { - direction: rtl; - } - .gr-radio, .gr-checkbox{ - margin-left: 0.25em; - } +#txt2img_extra_networks, #img2img_extra_networks{ + margin-top: -1em; +} - /* automatically generated with few manual modifications */ - .performance .time { - margin-right: unset; - margin-left: 0; - } - .justify-center.overflow-x-scroll { - justify-content: right; - } - .justify-center.overflow-x-scroll button:first-of-type { - margin-left: unset; - margin-right: auto; - } - .justify-center.overflow-x-scroll button:last-of-type { - margin-right: unset; - margin-left: auto; - } - #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - margin-right: unset; - margin-left: 8em; - } - #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - right: unset; - left: 0; - } - .progressDiv .progress{ - padding: 0 0 0 8px; - text-align: left; - } - #lightboxModal{ - left: unset; - right: 0; - } - .modalPrev, .modalNext{ - border-radius: 3px 0 0 3px; - } - .modalNext { - right: unset; - left: 0; - border-radius: 0 3px 3px 0; - } - #imageARPreview{ - left:unset; - right:0px; - } - #txt2img_skip, #img2img_skip{ - right: unset; - left: 0px; - } - #context-menu{ - box-shadow:-1px 1px 2px #CE6400; - } - .gr-box > div > div > input.gr-text-input{ - right: unset; - left: 0.5em; - } +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; } + +.extra-network-cards .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li{ + margin-left: 0.5em; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card .actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + diff --git a/webui.py b/webui.py index 865a7300..e8dd822a 100644 --- a/webui.py +++ b/webui.py @@ -9,16 +9,18 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook, errors +from modules import import_hook, errors, extra_networks +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path import torch + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -84,10 +86,17 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) - shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -209,6 +218,15 @@ def webui(): modules.sd_models.list_models() + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From 6d805b669e86233432f56ee1892d062103abe501 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 09:14:27 +0300 Subject: make CLIP interrogator download original text files if the directory does not exist remove random artist built-in extension (to re-added as a normal extension on demand) remove artists.csv (but what does it mean????????????????????) make interrogate buttons show Loading... when you click them --- README.md | 1 - artists.csv | 3041 -------------------- .../roll-artist/scripts/roll-artist.py | 50 - javascript/hints.js | 1 - modules/api/api.py | 8 - modules/artists.py | 25 - modules/interrogate.py | 55 +- modules/shared.py | 5 - modules/ui.py | 11 +- 9 files changed, 46 insertions(+), 3151 deletions(-) delete mode 100644 artists.csv delete mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py delete mode 100644 modules/artists.py (limited to 'modules/shared.py') diff --git a/README.md b/README.md index d783fdf0..1ac794e8 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion. - Running arbitrary python code from UI (must run with --allow-code to enable) - Mouseover hints for most UI elements - Possible to change defaults/mix/max/step values for UI elements via text config -- Random artist button - Tiling support, a checkbox to create images that can be tiled like textures - Progress bar and live image generation preview - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image diff --git a/artists.csv b/artists.csv deleted file mode 100644 index 1a61ed88..00000000 --- a/artists.csv +++ /dev/null @@ -1,3041 +0,0 @@ -artist,score,category -Peter Max,0.99715996,weird -Roy Lichtenstein,0.98272276,cartoon -Romero Britto,0.9498342,scribbles -Keith Haring,0.9431302,weird -Hiroshige,0.93995106,ukioe -Joan Miró,0.9169429,scribbles -Jean-Michel Basquiat,0.90080947,scribbles -Katsushika Hokusai,0.8887236,ukioe -Paul Klee,0.8868682,scribbles -Marc Chagall,0.8868168,scribbles -Karl Schmidt-Rottluff,0.88444495,scribbles -Howard Hodgkin,0.8808578,scribbles -Jean Metzinger,0.88056004,scribbles -Alma Thomas,0.87658304,weird -Rufino Tamayo,0.8749848,scribbles -Utagawa Hiroshige,0.8728796,ukioe -Chagall,0.8718535,scribbles -Harumi Hironaka,0.86914605,scribbles -Hans Hofmann,0.8686159,scribbles -Kawanabe Kyōsai,0.86612236,ukioe -Andy Warhol,0.8654825,scribbles -Barbara Takenaga,0.86223894,scribbles -Tatsuro Kiuchi,0.8597267,cartoon -Vincent Van Gogh,0.85538065,scribbles -Wassily Kandinsky,0.85490596,scribbles -Georges Seurat,0.8534801,scribbles -Karel Appel,0.8529153,scribbles -Sonia Delaunay,0.8506156,scribbles -Hokusai,0.85046995,ukioe -Eduardo Kobra,0.85036755,weird -Fra Angelico,0.84984255,fineart -Milton Avery,0.849746,scribbles -David Hockney,0.8496144,scribbles -Hiroshi Nagai,0.847129,cartoon -Aristarkh Lentulov,0.846537,scribbles -Lyonel Feininger,0.84573764,scribbles -Mary Blair,0.845709,scribbles -Ellsworth Kelly,0.8455428,scribbles -Jun Kaneko,0.8448367,scribbles -Roz Chast,0.8432013,weird -Ida Rentoul Outhwaite,0.84275174,scribbles -Robert Motherwell,0.8409468,scribbles -Garry Winogrand,0.83994275,black-white -Andrei Rublev,0.83950496,fineart -Alexander Calder,0.83832693,scribbles -Tomokazu Matsuyama,0.8376121,scribbles -August Macke,0.8362022,scribbles -Kazimir Malevich,0.8356527,scribbles -Richard Scarry,0.83554685,scribbles -Victor Vasarely,0.8335438,scribbles -Kitagawa Utamaro,0.83333457,ukioe -Matt Bors,0.83252287,scribbles -Emil Nolde,0.8323225,scribbles -Patrick Caulfield,0.8322225,scribbles -Charles Blackman,0.83200824,scribbles -Peter Doig,0.83111644,scribbles -Alexej von Jawlensky,0.8308932,scribbles -Rumiko Takahashi,0.8301817,anime -Eileen Agar,0.82945526,scribbles -Ernst Ludwig Kirchner,0.82756275,scribbles -Nicolas Delort,0.8261329,scribbles -Marsden Hartley,0.8250993,scribbles -Keith Negley,0.8212553,scribbles -Jamini Roy,0.8212199,scribbles -Quentin Blake,0.82115215,scribbles -Andy Kehoe,0.82063186,cartoon -George barbier,0.82046914,fineart -Frans Masereel,0.81997275,scribbles -Umberto Boccioni,0.81921184,scribbles -Conrad Roset,0.8190752,cartoon -Paul Ranson,0.81903255,scribbles -Yayoi Kusama,0.81886625,weird -Tomi Ungerer,0.81848705,scribbles -Saul Steinberg,0.81778854,scribbles -Jon Klassen,0.81773067,scribbles -W.W. Denslow,0.81708044,fineart -Helen Frankenthaler,0.81704986,scribbles -Jean Jullien,0.816437,scribbles -Brett Whiteley,0.81601924,scribbles -Giotto Di Bondone,0.81427747,fineart -Takashi Murakami,0.81338763,weird -Howard Finster,0.81333554,scribbles -Eduardo Paolozzi,0.81312317,scribbles -Charles Rennie Mackintosh,0.81297064,scribbles -Brandon Mably,0.8128239,weird -Rebecca Louise Law,0.81214285,weird -Victo Ngai,0.81195843,cartoon -Hanabusa Itchō II,0.81187993,ukioe -Edmund Dulac,0.81104875,scribbles -Ben Shahn,0.8104582,scribbles -Howard Arkley,0.8103746,scribbles -Wilfredo Lam,0.8096211,scribbles -Michael Deforge,0.8095954,scribbles -John Hoyland,0.8094592,fineart -Francesco Clemente,0.8090387,scribbles -Leonetto Cappiello,0.8087691,scribbles -Norman Ackroyd,0.80788493,scribbles -Bhupen Khakhar,0.8077607,scribbles -Jeremiah Ketner,0.8075384,cartoon -Chris Ofili,0.8073793,scribbles -Banksy,0.80695426,scribbles -Tom Whalen,0.805867,scribbles -Ernst Wilhelm Nay,0.805295,scribbles -Henri Rousseau,0.8049866,scribbles -Kunisada,0.80493814,ukioe -Naoko Takeuchi,0.80482674,anime -Kaethe Butcher,0.80406916,scribbles -Hasui Kawase,0.8040483,ukioe -Alvin Langdon Coburn,0.8035004,black-white -Stanley Donwood,0.8033054,scribbles -Agnes Martin,0.8028028,scribbles -Osamu Tezuka,0.8005524,cartoon -Frank Stella,0.80049455,scribbles -Dale Chihuly,0.79982775,digipa-high-impact -Evgeni Gordiets,0.79967916,scribbles -Janek Sedlar,0.7993992,fineart -Alasdair Gray,0.7992301,scribbles -Yasuo Kuniyoshi,0.79870003,ukioe -Edward Gorey,0.7984938,scribbles -Johannes Itten,0.798481,scribbles -Cuno Amiet,0.7979497,scribbles -M.C. Escher,0.7976657,scribbles -Albert Irvin,0.79688835,scribbles -Jack Gaughan,0.79443675,scribbles -Ravi Zupa,0.7939542,scribbles -Kay Nielsen,0.79385525,scribbles -Agnolo Gaddi,0.79369193,fineart -Alessandro Gottardo,0.79321593,scribbles -Paul Laffoley,0.79196846,scribbles -Giovanni Battista Piranesi,0.79111177,fineart -Adrian Tomine,0.79109013,scribbles -Adolph Gottlieb,0.79061794,scribbles -Milton Caniff,0.7905358,cartoon -Philip Guston,0.78994095,scribbles -Debbie Criswell,0.7895031,cartoon -Alice Pasquini,0.78949904,cartoon -Johannes Vermeer,0.78931487,fineart -Lisa Frank,0.7892591,cartoon -Patrick Heron,0.78889126,scribbles -Mikhail Nesterov,0.78814346,fineart -Cézanne,0.7879481,scribbles -Tristan Eaton,0.787513,scribbles -Jillian Tamaki,0.7868066,scribbles -Takato Yamamoto,0.78460765,ukioe -Martiros Saryan,0.7844924,scribbles -Emil Orlik,0.7842625,scribbles -Armand Guillaumin,0.7840431,scribbles -Jane Newland,0.7837676,scribbles -Paul Cézanne,0.78368753,scribbles -Tove Jansson,0.78356475,scribbles -Guido Crepax,0.7835321,cartoon -OSGEMEOS,0.7829088,weird -Albert Watson,0.48901254,digipa-med-impact -Emory Douglas,0.78179604,scribbles -Chris Van Allsburg,0.66413003,fineart -Ohara Koson,0.78132576,ukioe -Nicolas de Stael,0.7802779,scribbles -Aubrey Beardsley,0.77970016,scribbles -Hishikawa Moronobu,0.7794119,ukioe -Alfred Wallis,0.77926695,scribbles -Friedensreich Hundertwasser,0.7791805,scribbles -Eyvind Earle,0.7788089,scribbles -Giotto,0.7785216,fineart -Simone Martini,0.77843,fineart -Ivan Bilibin,0.77720606,fineart -Karl Blossfeldt,0.77652574,black-white -Duy Huynh,0.77634746,scribbles -Giovanni da Udina,0.7763063,fineart -Henri-Edmond Cross,0.7762994,fineart -Barry McGee,0.77618384,scribbles -William Kentridge,0.77615225,scribbles -Alexander Archipenko,0.7759824,scribbles -Jaume Plensa,0.7756799,weird -Bill Jacklin,0.77504414,fineart -Alberto Vargas,0.7747376,cartoon -Jean Dubuffet,0.7744374,scribbles -Eugène Grasset,0.7741958,fineart -Arthur Rackham,0.77418125,fineart -Yves Tanguy,0.77380997,scribbles -Elsa Beskow,0.7736908,fineart -Georgia O’Keeffe,0.77368987,scribbles -Georgia O'Keeffe,0.77368987,scribbles -Henri Cartier-Bresson,0.7735415,black-white -Andrea del Verrocchio,0.77307427,fineart -Mark Rothko,0.77294236,scribbles -Bruce Gilden,0.7256681,black-white -Gino Severini,0.77247965,scribbles -Delphin Enjolras,0.5594248,fineart -Alena Aenami,0.77210015,cartoon -Ed Freeman,0.42526615,digipa-low-impact -Apollonia Saintclair,0.7718383,anime -László Moholy-Nagy,0.771497,scribbles -Louis Glackens,0.7713224,fineart -Fang Lijun,0.77097225,fineart -Alfred Kubin,0.74409986,fineart -David Wojnarowicz,0.7705802,scribbles -Tara McPherson,0.77023256,scribbles -Gustav Doré,0.7367536,fineart -Patricia Polacco,0.7696109,scribbles -Norman Bluhm,0.7692634,fineart -Elizabeth Gadd,0.7691194,digipa-high-impact -Gabriele Münter,0.7690926,scribbles -David Inshaw,0.76905304,scribbles -Maurice Sendak,0.7690118,cartoon -Harry Clarke,0.7688428,cartoon -Howardena Pindell,0.7686921,n -Jamie Hewlett,0.7680373,scribbles -Steve Ditko,0.76725733,scribbles -Annie Soudain,0.7671485,scribbles -Albert Gleizes,0.76658314,scribbles -Henry Fuseli,0.69147265,fineart -Alain Laboile,0.67634284,c -Albrecht Altdorfer,0.7663378,fineart -Jack Butler Yeats,0.7661406,fineart -Yue Minjun,0.76583517,scribbles -Art Spiegelman,0.7656343,scribbles -Grete Stern,0.7656276,fineart -Mordecai Ardon,0.7648692,scribbles -Joel Sternfeld,0.76456416,digipa-high-impact -Milton Glaser,0.7641823,scribbles -Eishōsai Chōki,0.7639659,scribbles -Domenico Ghirlandaio,0.76372653,fineart -Alex Timmermans,0.64443207,digipa-high-impact -Andreas Vesalius,0.763446,fineart -Bruce McLean,0.76335883,scribbles -Jacob Lawrence,0.76330304,scribbles -Alex Katz,0.76317835,scribbles -Henri de Toulouse-Lautrec,0.76268333,scribbles -Franz Sedlacek,0.762062,scribbles -Paul Lehr,0.70854837,cartoon -Nicholas Roerich,0.76117516,scribbles -Henri Matisse,0.76110923,scribbles -Colin McCahon,0.76086944,scribbles -Max Dupain,0.6661642,black-white -Stephen Gammell,0.74001735,weird -Alberto Giacometti,0.7596302,scribbles -Goyō Hashiguchi,0.7595048,ukioe -Gustave Doré,0.7018832,fineart -Butcher Billy,0.7593378,cartoon -Pieter de Hooch,0.75916564,fineart -Gaetano Pesce,0.75906265,scribbles -Winsor McCay,0.7589382,scribbles -Claude Cahun,0.7588153,weird -Roger Ballen,0.64683115,black-white -Ellen Gallagher,0.758621,scribbles -Anton Corbijn,0.5550669,digipa-high-impact -Margaret Macdonald Mackintosh,0.75781375,fineart -Franz Kline,0.7576461,scribbles -Cimabue,0.75720495,fineart -André Kertész,0.7319392,black-white -Hans Hartung,0.75718236,scribbles -J. J. Grandville,0.7321584,fineart -David Octavius Hill,0.6333561,digipa-high-impact -teamLab,0.7566472,digipa-high-impact -Paul Gauguin,0.75635266,scribbles -Etel Adnan,0.75631833,scribbles -Barbara Kruger,0.7562784,scribbles -Franz Marc,0.75538874,scribbles -Saul Bass,0.75496316,scribbles -El Lissitzky,0.7549487,scribbles -Thomas Moran,0.6507399,fineart -Claude Monet,0.7541377,fineart -David Young Cameron,0.7541016,scribbles -W. Heath Robinson,0.75374347,cartoon -Yves Klein,0.7536262,fineart -Albert Pinkham Ryder,0.7338848,fineart -Elizabeth Shippen Green,0.7533686,fineart -Robert Stivers,0.5516287,fineart -Emily Kame Kngwarreye,0.7532016,weird -Charline von Heyl,0.753142,scribbles -Frida Kahlo,0.75303876,scribbles -Amy Sillman,0.752921,scribbles -Emperor Huizong of Song,0.7525214,ukioe -Edward Burne-Jones,0.75220466,fineart -Brett Weston,0.6891357,black-white -Charles E. Burchfield,0.75174403,scribbles -Hishida Shunsō,0.751617,fareast -Elaine de Kooning,0.7514996,scribbles -Gary Panter,0.7514598,scribbles -Frederick Hammersley,0.7514268,scribbles -Gustave Dore,0.6735896,fineart -Ephraim Moses Lilien,0.7510494,fineart -Hannah Hoch,0.7509496,scribbles -Shepard Fairey,0.7508583,scribbles -Richard Burlet,0.7506659,scribbles -Bill Brandt,0.6833408,black-white -Herbert List,0.68455493,black-white -Joseph Cornell,0.75023884,nudity -Nathan Wirth,0.6436741,black-white -John Kenn Mortensen,0.74758303,anime -Andre De Dienes,0.5683014,digipa-high-impact -Albert Robida,0.7485741,cartoon -Shintaro Kago,0.7484431,anime -Sidney Nolan,0.74809414,scribbles -Patrice Murciano,0.61973965,fineart -Brian Stelfreeze,0.7478351,scribbles -Francisco De Goya,0.6954584,fineart -William Morris,0.7478111,fineart -Honoré Daumier,0.74767774,scribbles -Hubert Robert,0.6863421,fineart -Marianne von Werefkin,0.7475825,fineart -Edvard Munch,0.74719715,scribbles -Victor Brauner,0.74719006,scribbles -George Inness,0.7470588,fineart -Naoki Urasawa,0.7469665,anime -Kilian Eng,0.7468486,scribbles -Bordalo II,0.7467364,digipa-high-impact -Katsuhiro Otomo,0.746364,anime -Maximilien Luce,0.74609685,fineart -Amy Earles,0.74603415,fineart -Jeanloup Sieff,0.7196009,black-white -William Zorach,0.74574494,scribbles -Pascale Campion,0.74516207,fineart -Dorothy Lathrop,0.74418795,fineart -Sofonisba Anguissola,0.74418664,fineart -Natalia Goncharova,0.74414873,scribbles -August Sander,0.6644566,black-white -Jasper Johns,0.74395454,scribbles -Arthur Dove,0.74383533,scribbles -Darwyn Cooke,0.7435789,scribbles -Leonardo Da Vinci,0.6825216,fineart -Fra Filippo Lippi,0.7433891,fineart -Pierre-Auguste Renoir,0.742464,fineart -Jeff Lemire,0.7422893,scribbles -Al Williamson,0.742113,cartoon -Childe Hassam,0.7418015,fineart -Francisco Goya,0.69522625,fineart -Alphonse Mucha,0.74171394,special -Cleon Peterson,0.74163914,scribbles -J.M.W. Turner,0.65582645,fineart -Walter Crane,0.74146044,fineart -Brassaï,0.6361966,digipa-high-impact -Virgil Finlay,0.74133486,fineart -Fernando Botero,0.7412504,nudity -Ben Nicholson,0.7411573,scribbles -Robert Rauschenberg,0.7410054,fineart -David Wiesner,0.7406237,scribbles -Bartolome Esteban Murillo,0.6933951,fineart -Jean Arp,0.7403873,scribbles -Andre Kertesz,0.7228358,black-white -Simeon Solomon,0.66441345,fineart -Hugh Ferriss,0.72443527,black-white -Agnes Lawrence Pelton,0.73960555,scribbles -Charles Camoin,0.7395686,scribbles -Paul Strand,0.7080332,black-white -Charles Gwathmey,0.7394747,scribbles -Bartolomé Esteban Murillo,0.7011274,fineart -Oskar Kokoschka,0.7392038,scribbles -Bruno Munari,0.73918355,weird -Willem de Kooning,0.73916197,scribbles -Hans Memling,0.7387886,fineart -Chris Mars,0.5861489,digipa-high-impact -Hiroshi Yoshida,0.73787534,ukioe -Hundertwasser,0.7377672,fineart -David Bowie,0.73773724,weird -Ettore Sottsass,0.7376095,digipa-high-impact -Antanas Sutkus,0.7369492,black-white -Leonora Carrington,0.73726475,scribbles -Hieronymus Bosch,0.7369955,scribbles -A. J. Casson,0.73666203,scribbles -Chaim Soutine,0.73662066,scribbles -Artur Bordalo,0.7364549,weird -Thomas Allom,0.68792284,fineart -Louis Comfort Tiffany,0.7363504,fineart -Philippe Druillet,0.7363382,cartoon -Jan Van Eyck,0.7360621,fineart -Sandro Botticelli,0.7359395,fineart -Hieronim Bosch,0.7359308,scribbles -Everett Shinn,0.7355817,fineart -Camille Corot,0.7355603,fineart -Nick Sharratt,0.73470485,scribbles -Fernand Léger,0.7079839,scribbles -Robert S. Duncanson,0.7346282,fineart -Hieronymous Bosch,0.73453265,scribbles -Charles Addams,0.7344034,scribbles -Studio Ghibli,0.73439026,anime -Archibald Motley,0.7343683,scribbles -Anton Fadeev,0.73433846,cartoon -Uemura Shoen,0.7342118,ukioe -Ando Fuchs,0.73406494,black-white -Jessie Willcox Smith,0.73398125,fineart -Alex Garant,0.7333658,scribbles -Lawren Harris,0.73331416,scribbles -Anne Truitt,0.73297834,scribbles -Richard Lindner,0.7328564,scribbles -Sailor Moon,0.73281246,anime -Bridget Bate Tichenor,0.73274165,scribbles -Ralph Steadman,0.7325864,scribbles -Annibale Carracci,0.73251307,fineart -Dürer,0.7324789,fineart -Abigail Larson,0.7319012,cartoon -Bill Traylor,0.73189163,scribbles -Louis Rhead,0.7318623,fineart -David Burliuk,0.731803,scribbles -Camille Pissarro,0.73172396,fineart -Catrin Welz-Stein,0.73117495,scribbles -William Etty,0.6497544,nudity -Pierre Bonnard,0.7310132,scribbles -Benoit B. Mandelbrot,0.5033001,digipa-med-impact -Théodore Géricault,0.692039,fineart -Andy Goldsworthy,0.7307565,digipa-high-impact -Alfred Sisley,0.7306032,fineart -Charles-Francois Daubigny,0.73057353,fineart -Karel Thole,0.7305395,cartoon -Andre Derain,0.73050404,scribbles -Larry Poons,0.73023695,fineart -Beauford Delaney,0.72999024,scribbles -Ruth Bernhard,0.72990334,black-white -David Alfaro Siqueiros,0.7297947,scribbles -Gaugin,0.729636,fineart -Carl Larsson,0.7296195,cartoon -Albrecht Dürer,0.72946966,fineart -Henri De Toulouse Lautrec,0.7294263,cartoon -Shotaro Ishinomori,0.7292093,anime -Hope Gangloff,0.729082,scribbles -Vivian Maier,0.72897506,digipa-high-impact -Alex Andreev,0.6442978,digipa-high-impact -Julie Blackmon,0.72862685,c -Arthur Melville,0.7286146,fineart -Henri Michaux,0.599607,fineart -William Steig,0.7283096,scribbles -Octavio Ocampo,0.72814554,scribbles -Cy Twombly,0.72814107,scribbles -Guy Denning,0.67375445,fineart -Maxfield Parrish,0.7280283,fineart -Randolph Caldecott,0.7279564,fineart -Duccio,0.72795,fineart -Ray Donley,0.5837457,fineart -Hiroshi Sugimoto,0.6497892,digipa-high-impact -Daniela Uhlig,0.4691466,special -Go Nagai,0.72770613,anime -Carlo Crivelli,0.72764605,fineart -Helmut Newton,0.44433144,digipa-low-impact -Josef Albers,0.7061394,scribbles -Henry Moret,0.7274567,fineart -André Masson,0.727404,scribbles -Henri Fantin Latour,0.72732764,fineart -Theo van Rysselberghe,0.7272843,fineart -John Wayne Gacy,0.72686327,scribbles -Carlos Schwabe,0.7267612,fineart -Herbert Bayer,0.7094297,scribbles -Domenichino,0.72667265,fineart -Liam Wong,0.7262276,special -George Caleb Bingham,0.7262154,digipa-high-impact -Gigadō Ashiyuki,0.7261864,fineart -Chaïm Soutine,0.72603923,scribbles -Ary Scheffer,0.64913243,fineart -Rockwell Kent,0.7257272,scribbles -Jean-Paul Riopelle,0.72570604,fineart -Ed Mell,0.6637067,cartoon -Ismail Inceoglu,0.72561014,special -Edgar Degas,0.72538006,fineart -Giorgione,0.7252798,fineart -Charles-François Daubigny,0.7252482,fineart -Arthur Lismer,0.7251765,scribbles -Aaron Siskind,0.4852289,digipa-med-impact -Arkhip Kuindzhi,0.7249981,fineart -Joseph Mallord William Turner,0.6834406,fineart -Dante Gabriel Rossetti,0.7244541,fineart -Ernst Haeckel,0.6660129,fineart -Rebecca Guay,0.72439146,cartoon -Anthony Gerace,0.636678,digipa-high-impact -Martin Kippenberger,0.72418386,scribbles -Diego Giacometti,0.72415763,scribbles -Dmitry Kustanovich,0.7241322,cartoon -Dora Carrington,0.7239633,scribbles -Shusei Nagaoko,0.7238965,anime -Odilon Redon,0.72381747,scribbles -Shohei Otomo,0.7132803,nudity -Barnett Newman,0.7236389,scribbles -Jean Fouquet,0.7235963,fineart -Gustav Klimt,0.72356784,nudity -Francisco Josè de Goya,0.6589663,fineart -Bonnard Pierre,0.72309464,nudity -Brooke Shaden,0.61281693,digipa-high-impact -Mao Hamaguchi,0.7228292,scribbles -Frederick Edwin Church,0.64416,fineart -Asher Brown Durand,0.72264796,fineart -George Baselitz,0.7223453,scribbles -Sam Bosma,0.7223237,fineart -Asaf Hanuka,0.72222745,scribbles -David Teniers the Younger,0.7221168,fineart -Nicola Samori,0.68747556,nudity -Claude Lorrain,0.7217102,fineart -Hermenegildo Anglada Camarasa,0.7214374,nudity -Pablo Picasso,0.72142905,scribbles -Howard Chaykin,0.7213998,cartoon -Ferdinand Hodler,0.7213758,nudity -Farel Dalrymple,0.7213298,fineart -Lyubov Popova,0.7213024,scribbles -Albin Egger-Lienz,0.72120845,fineart -Geertgen tot Sint Jans,0.72107565,fineart -Kate Greenaway,0.72069687,fineart -Louise Bourgeois,0.7206516,fineart -Miriam Schapiro,0.72026414,fineart -Pieter Claesz,0.7200939,fineart -George B. Bridgman,0.5592567,fineart -Piet Mondrian,0.71990657,scribbles -Michelangelo Merisi Da Caravaggio,0.7094674,fineart -Marie Spartali Stillman,0.71986604,fineart -Gertrude Abercrombie,0.7196962,scribbles -Louis Icart,0.7195913,fineart -David Driskell,0.719564,scribbles -Paula Modersohn-Becker,0.7193769,scribbles -George Hurrell,0.57496595,digipa-high-impact -Andrea Mantegna,0.7190254,fineart -Silvestro Lega,0.71891177,fineart -Junji Ito,0.7188978,anime -Jacob Hashimoto,0.7186867,digipa-high-impact -Benjamin West,0.6642946,fineart -David Teniers the Elder,0.7181293,fineart -Roberto Matta,0.71808386,fineart -Chiho Aoshima,0.71801454,anime -Amedeo Modigliani,0.71788836,scribbles -Raja Ravi Varma,0.71788085,fineart -Roberto Ferri,0.538221,nudity -Winslow Homer,0.7176876,fineart -Horace Vernet,0.65729,fineart -Lucas Cranach the Elder,0.71738195,fineart -Godfried Schalcken,0.625893,fineart -Affandi,0.7170285,nudity -Diane Arbus,0.655138,digipa-high-impact -Joseph Ducreux,0.65247905,digipa-high-impact -Berthe Morisot,0.7165984,fineart -Hilma af Klint,0.71643853,scribbles -Filippino Lippi,0.7163017,fineart -Leonid Afremov,0.7163005,fineart -Chris Ware,0.71628594,scribbles -Marius Borgeaud,0.7162446,scribbles -M.W. Kaluta,0.71612585,cartoon -Govert Flinck,0.68975246,fineart -Charles Demuth,0.71605396,scribbles -Coles Phillips,0.7158309,scribbles -Oskar Fischinger,0.6721027,digipa-high-impact -David Teniers III,0.71569765,fineart -Jean Delville,0.7156771,fineart -Antonio Saura,0.7155949,scribbles -Bridget Riley,0.7155669,fineart -Gordon Parks,0.5759978,digipa-high-impact -Anselm Kiefer,0.71514887,scribbles -Remedios Varo,0.7150927,weird -Franz Hegi,0.71495223,scribbles -Kati Horna,0.71486115,black-white -Arshile Gorky,0.71459055,scribbles -David LaChapelle,0.7144903,scribbles -Fritz von Dardel,0.71446383,scribbles -Edward Ruscha,0.71438885,fineart -Blanche Hoschedé Monet,0.7143073,fineart -Alexandre Calame,0.5735474,fineart -Sean Scully,0.714154,fineart -Alexandre Benois,0.7141515,fineart -Sally Mann,0.6534312,black-white -Thomas Eakins,0.7141104,fineart -Arnold Böcklin,0.71407956,fineart -Alfonse Mucha,0.7139052,special -Damien Hirst,0.7136273,scribbles -Lee Krasner,0.71362555,scribbles -Dorothea Lange,0.71361613,black-white -Juan Gris,0.7132987,scribbles -Bernardo Bellotto,0.70720065,fineart -John Martin,0.5376847,fineart -Harriet Backer,0.7131594,fineart -Arnold Newman,0.5736342,digipa-high-impact -Gjon Mili,0.46520913,digipa-low-impact -Asger Jorn,0.7129575,scribbles -Chesley Bonestell,0.6063316,fineart -Agostino Carracci,0.7128167,fineart -Peter Wileman,0.71271706,cartoon -Chen Hongshou,0.71268153,ukioe -Catherine Hyde,0.71266896,scribbles -Andrea Pozzo,0.626546,fineart -Kitty Lange Kielland,0.7125735,fineart -Cornelis Saftleven,0.6684047,fineart -Félix Vallotton,0.71237606,fineart -Albrecht Durer,0.7122327,fineart -Jackson Pollock,0.71222305,scribbles -John Bratby,0.7122171,scribbles -Beksinski,0.71218586,fineart -James Thomas Watts,0.5959548,fineart -Konstantin Korovin,0.71188873,fineart -Gustave Caillebotte,0.71181154,fineart -Dean Ellis,0.50233585,fineart -Friedrich von Amerling,0.6420181,fineart -Christopher Balaskas,0.67935324,special -Alexander Rodchenko,0.67415404,scribbles -Alfred Cheney Johnston,0.6647291,fineart -Mikalojus Konstantinas Ciurlionis,0.710677,scribbles -Jean-Antoine Watteau,0.71061164,fineart -Paul Delvaux,0.7105914,scribbles -Francesco del Cossa,0.7104901,nudity -Isaac Cordal,0.71046066,weird -Hikari Shimoda,0.7104546,weird -François Boucher,0.67153126,fineart -Akos Major,0.7103802,digipa-high-impact -Bernard Buffet,0.7103491,cartoon -Brandon Woelfel,0.6727086,digipa-high-impact -Edouard Manet,0.7101296,fineart -Auguste Herbin,0.6866145,scribbles -Eugene Delacroix,0.70995826,fineart -L. Birge Harrison,0.70989627,fineart -Howard Pyle,0.70979863,fineart -Diane Dillon,0.70968723,scribbles -Hans Erni,0.7096618,scribbles -Richard Diebenkorn,0.7096184,scribbles -Thomas Gainsborough,0.6759419,fineart -Maria Sibylla Merian,0.7093275,fineart -François Joseph Heim,0.6175854,fineart -E. H. Shepard,0.7091189,cartoon -Hsiao-Ron Cheng,0.7090618,scribbles -Canaletto,0.7090392,fineart -John Atkinson Grimshaw,0.7087531,fineart -Giovanni Battista Tiepolo,0.6754107,fineart -Cornelis van Poelenburgh,0.69821274,fineart -Raina Telgemeier,0.70846486,scribbles -Francesco Hayez,0.6960006,fineart -Gilbert Stuart,0.659772,fineart -Konstantin Yuon,0.7081486,fineart -Antonello da Messina,0.70806944,fineart -Austin Osman Spare,0.7079903,fineart -James Ensor,0.70781446,scribbles -Claude Bonin-Pissarro,0.70739406,fineart -Mikhail Vrubel,0.70738363,fineart -Angelica Kauffman,0.6748828,fineart -Viktor Vasnetsov,0.7072422,fineart -Alphonse Osbert,0.70724136,fineart -Tsutomu Nihei,0.7070495,anime -Harvey Quaytman,0.63613266,fineart -Jamie Hawkesworth,0.706914,digipa-high-impact -Francesco Guardi,0.70682615,fineart -Jean-Honoré Fragonard,0.6518248,fineart -Brice Marden,0.70673287,digipa-high-impact -Charles-Amédée-Philippe van Loo,0.6725916,fineart -Mati Klarwein,0.7066092,n -Gerard ter Borch,0.706589,fineart -Dan Hillier,0.48966256,digipa-med-impact -Federico Barocci,0.682664,fineart -Henri Le Sidaner,0.70637953,fineart -Olivier Bonhomme,0.7063748,scribbles -Edward Weston,0.7061382,black-white -Giovanni Paolo Cavagna,0.6840265,fineart -Germaine Krull,0.6621777,black-white -Hans Holbein the Younger,0.70590156,fineart -François Bocion,0.6272365,fineart -Georg Baselitz,0.7053314,scribbles -Caravaggio,0.7050303,fineart -Anne Rothenstein,0.70502245,scribbles -Wadim Kashin,0.43714935,digipa-low-impact -Heinrich Lefler,0.7048054,fineart -Jacob van Ruisdael,0.7047918,fineart -Bartholomeus van Bassen,0.6676872,fineart -Jeffrey Smith art,0.56750107,fineart -Anne Packard,0.7046703,weird -Jean-François Millet,0.7045456,fineart -Andrey Remnev,0.7041204,digipa-high-impact -Fujiwara Takanobu,0.70410216,ukioe -Elliott Erwitt,0.69950557,black-white -Fern Coppedge,0.7036215,fineart -Bartholomeus van der Helst,0.66411966,fineart -Rembrandt Van Rijn,0.6979987,fineart -Rene Magritte,0.703457,scribbles -Aelbert Cuyp,0.7033657,fineart -Gerda Wegener,0.70319015,scribbles -Graham Sutherland,0.7031714,scribbles -Gerrit Dou,0.7029986,fineart -August Friedrich Schenck,0.6801586,fineart -George Herriman,0.7028568,scribbles -Stanisław Szukalski,0.6903354,fineart -Slim Aarons,0.70222545,digipa-high-impact -Ernst Thoms,0.70221686,fineart -Louis Wain,0.702186,fineart -Artemisia Gentileschi,0.70198226,fineart -Eugène Delacroix,0.70155394,fineart -Peter Bagge,0.70127463,scribbles -Jeffrey Catherine Jones,0.7012148,cartoon -Eugène Carrière,0.65272695,fineart -Alexander Millar,0.7011144,scribbles -Nobuyoshi Araki,0.70108867,fareast -Tintoretto,0.6702795,fineart -André Derain,0.7009005,scribbles -Charles Maurice Detmold,0.70079994,fineart -Francisco de Zurbarán,0.7007234,fineart -Laurie Greasley,0.70072114,cartoon -Lynda Benglis,0.7006948,digipa-high-impact -Cecil Beaton,0.66362655,black-white -Gustaf Tenggren,0.7006041,cartoon -Abdur Rahman Chughtai,0.7004994,ukioe -Constantin Brancusi,0.7004367,scribbles -Mikhail Larionov,0.7004066,fineart -Jan van Kessel the Elder,0.70040506,fineart -Chantal Joffe,0.70036674,scribbles -Charles-André van Loo,0.6830367,fineart -Reginald Marsh,0.6301042,fineart -Elsa Bleda,0.70005083,digipa-high-impact -Peter Paul Rubens,0.65745676,fineart -Eugène Boudin,0.70001304,fineart -Charles Willson Peale,0.66907954,fineart -Brian Mashburn,0.63395154,digipa-high-impact -Barkley L. Hendricks,0.69986427,n -Yoshiyuki Tomino,0.6998095,anime -Guido Reni,0.6416875,fineart -Lynd Ward,0.69958556,fineart -John Constable,0.6907788,fineart -František Kupka,0.6993329,fineart -Pieter Bruegel The Elder,0.6992879,scribbles -Benjamin Gerritsz Cuyp,0.6992173,fineart -Nicolas Mignard,0.6988214,fineart -Augustus Edwin Mulready,0.6482165,fineart -Andrea del Sarto,0.698532,fineart -Edward Steichen,0.69837445,black-white -James Abbott McNeill Whistler,0.69836813,fineart -Alphonse Legros,0.6983243,fineart -Ivan Aivazovsky,0.64588225,fineart -Giovanni Francesco Barbieri,0.6981316,fineart -Grace Cossington Smith,0.69811064,fineart -Bert Stern,0.53411555,scribbles -Mary Cassatt,0.6980135,fineart -Jules Bastien-Lepage,0.69796044,fineart -Max Ernst,0.69777006,fineart -Kentaro Miura,0.697743,anime -Georges Rouault,0.69758564,scribbles -Josephine Wall,0.6973667,fineart -Anne-Louis Girodet,0.58104825,nudity -Bert Hardy,0.6972966,black-white -Adriaen van de Velde,0.69716156,fineart -Andreas Achenbach,0.61108655,fineart -Hayv Kahraman,0.69705284,fineart -Beatrix Potter,0.6969851,fineart -Elmer Bischoff,0.6968948,fineart -Cornelis de Heem,0.6968436,fineart -Inio Asano,0.6965007,anime -Alfred Henry Maurer,0.6964837,fineart -Gottfried Helnwein,0.6962953,digipa-high-impact -Paul Barson,0.54196984,digipa-high-impact -Roger de La Fresnaye,0.69620967,fineart -Abraham Mignon,0.60605425,fineart -Albert Bloch,0.69573116,nudity -Charles Dana Gibson,0.67155975,fineart -Alexandre-Évariste Fragonard,0.6507174,fineart -Ernst Fuchs,0.6953538,nudity -Alfredo Jaar,0.6952965,digipa-high-impact -Judy Chicago,0.6952246,weird -Frans van Mieris the Younger,0.6951849,fineart -Aertgen van Leyden,0.6951305,fineart -Emily Carr,0.69512105,fineart -Frances MacDonald,0.6950408,scribbles -Hannah Höch,0.69495845,scribbles -Gillis Rombouts,0.58770025,fineart -Käthe Kollwitz,0.6947756,fineart -Barbara Stauffacher Solomon,0.6920825,fineart -Georges Lacombe,0.6944455,fineart -Gwen John,0.6944161,fineart -Terada Katsuya,0.6944026,cartoon -James Gillray,0.6871335,fineart -Robert Crumb,0.69420326,fineart -Bruce Pennington,0.6545669,fineart -David Firth,0.69400465,scribbles -Arthur Boyd,0.69399726,fineart -Antonin Artaud,0.67321455,fineart -Giuseppe Arcimboldo,0.6937329,fineart -Jim Mahfood,0.6936606,cartoon -Ossip Zadkine,0.6494374,scribbles -Atelier Olschinsky,0.69349927,fineart -Carl Frederik von Breda,0.57274634,fineart -Ken Sugimori,0.6932626,anime -Chris Friel,0.5399168,fineart -Andrew Macara,0.69307995,fineart -Alexander Jansson,0.69298327,scribbles -Anne Brigman,0.6865817,black-white -George Ault,0.66756654,fineart -Arkhyp Kuindzhi,0.6928072,digipa-high-impact -Emiliano Ponzi,0.69278395,scribbles -William Holman Hunt,0.6927663,fineart -Tamara Lempicka,0.6386007,scribbles -Mark Ryden,0.69259655,fineart -Giovanni Paolo Pannini,0.6802902,fineart -Carl Barks,0.6923666,cartoon -Fritz Bultman,0.6318746,fineart -Salomon van Ruysdael,0.690313,fineart -Carrie Mae Weems,0.6645416,n -Agostino Arrivabene,0.61166185,fineart -Gustave Boulanger,0.655797,fineart -Henry Justice Ford,0.51214355,fareast -Bernardo Strozzi,0.63510317,fineart -André Lhote,0.68718815,scribbles -Paul Corfield,0.6915611,scribbles -Gifford Beal,0.6914777,fineart -Hirohiko Araki,0.6914078,anime -Emil Carlsen,0.691326,fineart -Frans van Mieris the Elder,0.6912799,fineart -Simon Stalenhag,0.6912775,special -Henry van de Velde,0.64838886,fineart -Eleanor Fortescue-Brickdale,0.6909729,fineart -Thomas W Schaller,0.69093937,special -NHK Animation,0.6907677,cartoon -Euan Uglow,0.69060403,scribbles -Hendrick Goltzius,0.69058937,fineart -William Blake,0.69038224,fineart -Vito Acconci,0.58409876,digipa-high-impact -Billy Childish,0.6902057,scribbles -Ben Quilty,0.6875855,fineart -Mark Briscoe,0.69010437,fineart -Adriaen van de Venne,0.6899867,fineart -Alasdair McLellan,0.6898454,digipa-high-impact -Ed Paschke,0.68974686,scribbles -Guy Rose,0.68960273,fineart -Barbara Hepworth,0.68958247,fineart -Edward Henry Potthast,0.6895703,fineart -Francis Bacon,0.6895397,scribbles -Pawel Kuczynski,0.6894536,fineart -Bjarke Ingels,0.68933153,digipa-high-impact -Henry Ossawa Tanner,0.68932164,fineart -Alessandro Allori,0.6892961,fineart -Abraham van Calraet,0.63841593,fineart -Egon Schiele,0.6891415,scribbles -Tim Doyle,0.5474768,digipa-high-impact -Grandma Moses,0.6890782,fineart -John Frederick Kensett,0.61981744,fineart -Giacomo Balla,0.68893707,fineart -Jamie Baldridge,0.6546651,digipa-high-impact -Max Beckmann,0.6884731,scribbles -Cornelis van Haarlem,0.6677613,fineart -Edward Hopper,0.6884258,special -Barkley Hendricks,0.6883637,n -Patrick Dougherty,0.688321,digipa-high-impact -Karol Bak,0.6367705,fineart -Pierre Puvis de Chavannes,0.6880703,fineart -Antoni Tàpies,0.685689,fineart -Alexander Nasmyth,0.57695735,fineart -Laurent Grasso,0.5793272,fineart -Camille Walala,0.6076875,digipa-high-impact -Fairfield Porter,0.68790644,fineart -Alex Colville,0.68787855,fineart -Herb Ritts,0.51471305,scribbles -Gerhard Munthe,0.687658,fineart -Susan Seddon Boulet,0.68762136,scribbles -Liu Ye,0.68760437,fineart -Robert Antoine Pinchon,0.68744636,fineart -Fujiwara Nobuzane,0.6873439,fineart -Frederick Carl Frieseke,0.6873361,fineart -Aert van der Neer,0.6159286,fineart -Allen Jones,0.6869935,scribbles -Anja Millen,0.6064488,digipa-high-impact -Esaias van de Velde,0.68673944,fineart -Gyoshū Hayami,0.68665624,anime -William Hogarth,0.6720842,fineart -Frederic Church,0.6865637,fineart -Cyril Rolando,0.68644965,cartoon -Frederic Edwin Church,0.6863009,fineart -Thomas Rowlandson,0.66726154,fineart -Joachim Brohm,0.68601763,digipa-high-impact -Cristofano Allori,0.6858083,fineart -Adrianus Eversen,0.58259964,fineart -Richard Dadd,0.68546164,fineart -Ambrosius Bosschaert II,0.6854217,fineart -Paolo Veronese,0.68422073,fineart -Abraham van den Tempel,0.66463804,fineart -Duncan Grant,0.6852565,scribbles -Hendrick Cornelisz. van Vliet,0.6851691,fineart -Geof Darrow,0.6851174,scribbles -Émile Bernard,0.6850957,fineart -Brian Bolland,0.68496394,scribbles -James Gilleard,0.6849431,cartoon -Anton Raphael Mengs,0.6689196,fineart -Augustus Jansson,0.6845705,digipa-high-impact -Hendrik Goltzius,0.6843367,fineart -Domenico Quaglio the Younger,0.65769434,fineart -Cicely Mary Barker,0.6841806,fineart -William Eggleston,0.6840795,digipa-high-impact -David Choe,0.6840449,scribbles -Adam Elsheimer,0.6716068,fineart -Heinrich Danioth,0.5390186,fineart -Franz Stuck,0.6836468,fineart -Bernie Wrightson,0.64101505,fineart -Dorina Costras,0.6835419,fineart -El Greco,0.68343943,fineart -Gatōken Shunshi,0.6833314,anime -Giovanni Bellini,0.67622876,fineart -Aron Wiesenfeld,0.68331146,nudity -Boris Kustodiev,0.68329334,fineart -Alec Soth,0.5597321,digipa-high-impact -Artus Scheiner,0.6313348,fineart -Kelly Vivanco,0.6830933,scribbles -Shaun Tan,0.6830649,fineart -Anthony van Dyck,0.6577681,fineart -Neil Welliver,0.68297863,nudity -Robert McCall,0.68294585,fineart -Sandra Chevrier,0.68284667,scribbles -Yinka Shonibare,0.68256056,n -Arthur Tress,0.6301861,digipa-high-impact -Richard McGuire,0.6820089,scribbles -Anni Albers,0.65708244,digipa-high-impact -Aleksey Savrasov,0.65207493,fineart -Wayne Barlowe,0.6537874,fineart -Giorgio de Chirico,0.6815907,fineart -Ernest Procter,0.6815795,fineart -Adriaen Brouwer,0.6815058,fineart -Ilya Glazunov,0.6813533,fineart -Alison Bechdel,0.68096143,scribbles -Carl Holsoe,0.68082225,fineart -Alfred Edward Chalon,0.6464571,fineart -Gerard David,0.68058,fineart -Basil Blackshaw,0.6805679,fineart -Gerrit Adriaenszoon Berckheyde,0.67340267,fineart -George Hendrik Breitner,0.6804209,fineart -Abraham Bloemaert,0.68036544,fineart -Ferdinand Van Kessel,0.67742276,fineart -Hugo Simberg,0.68031186,fineart -Gaston Bussière,0.665221,fineart -Shawn Coss,0.42407864,digipa-low-impact -Hanabusa Itchō,0.68023074,ukioe -Magnus Enckell,0.6801553,fineart -Gary Larson,0.6801336,scribbles -George Manson,0.68013126,digipa-high-impact -Hayao Miyazaki,0.6800754,anime -Carl Spitzweg,0.66581815,fineart -Ambrosius Holbein,0.6798341,fineart -Domenico Pozzi,0.6434162,fineart -Dorothea Tanning,0.6797955,fineart -Jeannette Guichard-Bunel,0.5251578,digipa-high-impact -Victor Moscoso,0.62962687,fineart -Francis Picabia,0.6795391,scribbles -Charles W. Bartlett,0.67947805,fineart -David A Hardy,0.5554935,fineart -C. R. W. Nevinson,0.67946506,fineart -Man Ray,0.6507145,scribbles -Albert Bierstadt,0.67935765,fineart -Charles Le Brun,0.6758479,fineart -Lovis Corinth,0.67913896,fineart -Herbert Abrams,0.5507507,digipa-high-impact -Giorgio Morandi,0.6789025,fineart -Agnolo Bronzino,0.6787985,fineart -Abraham Pether,0.66922426,fineart -John Bauer,0.6786695,fineart -Arthur Stanley Wilkinson,0.67860866,fineart -Arthur Wardle,0.5510789,fineart -George Romney,0.62868094,fineart -Laurie Lipton,0.5201844,fineart -Mickalene Thomas,0.45433685,digipa-low-impact -Alice Rahon,0.6777824,scribbles -Gustave Van de Woestijne,0.6777346,scribbles -Laurel Burch,0.67766285,fineart -Hendrik Gerritsz Pot,0.67750573,fineart -John William Waterhouse,0.677472,fineart -Conor Harrington,0.5967809,fineart -Gabriel Ba,0.6773366,cartoon -Franz Xaver Winterhalter,0.62229514,fineart -George Cruikshank,0.6473593,fineart -Hyacinthe Rigaud,0.67717785,fineart -Cornelis Claesz van Wieringen,0.6770269,fineart -Adriaen van Outrecht,0.67682564,fineart -Yaacov Agam,0.6767926,fineart -Franz von Lenbach,0.61948,fineart -Clyfford Still,0.67667866,fineart -Alexander Roslin,0.66719526,fineart -Barry Windsor Smith,0.6765375,cartoon -Takeshi Obata,0.67643225,anime -John Harris,0.47712502,fineart -Bruce Davidson,0.6763525,digipa-high-impact -Hendrik Willem Mesdag,0.6762745,fineart -Makoto Shinkai,0.67610705,anime -Andreas Gursky,0.67610145,digipa-high-impact -Mike Winkelmann (Beeple),0.6510196,digipa-high-impact -Gustave Moreau,0.67607844,fineart -Frank Weston Benson,0.6760142,fineart -Eduardo Kingman,0.6759026,fineart -Benjamin Williams Leader,0.5611925,fineart -Hervé Guibert,0.55973417,black-white -Cornelis Dusart,0.6753622,fineart -Amédée Guillemin,0.6752696,fineart -Alessio Albi,0.6752633,digipa-high-impact -Matthias Grünewald,0.6751779,fineart -Fujishima Takeji,0.6751577,anime -Georges Braque,0.67514753,scribbles -John Salminen,0.67498183,fineart -Atey Ghailan,0.674873,scribbles -Giovanni Antonio Galli,0.657484,fineart -Julie Mehretu,0.6748382,fineart -Jean Auguste Dominique Ingres,0.6746286,fineart -Francesco Albani,0.6621554,fineart -Anato Finnstark,0.6744919,digipa-high-impact -Giovanni Bernardino Mazzolini,0.64416045,fineart -Antoine Le Nain,0.6233709,fineart -Ford Madox Brown,0.6743224,fineart -Gerhard Richter,0.67426133,fineart -theCHAMBA,0.6742506,cartoon -Edward Julius Detmold,0.67421955,fineart -George Stubbs,0.6209227,fineart -George Tooker,0.6740602,scribbles -Faith Ringgold,0.6739976,scribbles -Giambattista Pittoni,0.5792371,fineart -George Bellows,0.6737008,fineart -Aldus Manutius,0.67366326,fineart -Ambrosius Bosschaert,0.67364097,digipa-high-impact -Michael Parkes,0.6133628,fineart -Hans Bellmer,0.6735973,nudity -Sir James Guthrie,0.67359626,fineart -Charles Spencelayh,0.67356884,fineart -Ivan Shishkin,0.6734136,fineart -Hans Holbein the Elder,0.6733856,fineart -Filip Hodas,0.60053295,digipa-high-impact -Herman Saftleven,0.6732188,digipa-high-impact -Dirck de Quade van Ravesteyn,0.67309594,fineart -Joe Fenton,0.6730916,scribbles -Arnold Bocklin,0.6730706,fineart -Baiōken Eishun,0.6730663,anime -Giovanni Giacometti,0.6730505,fineart -Giovanni Battista Gaulli,0.65036476,fineart -William Stout,0.672887,fineart -Gavin Hamilton,0.5982757,fineart -John Stezaker,0.6726847,black-white -Frederick McCubbin,0.67263377,fineart -Christoph Ludwig Agricola,0.62750757,fineart -Alice Neel,0.67255914,scribbles -Giovanni Battista Venanzi,0.61996603,fineart -Miho Hirano,0.6724092,anime -Tom Thomson,0.6723876,fineart -Alfred Munnings,0.6723851,fineart -David Wilkie,0.6722781,fineart -Adriaen van Ostade,0.67220736,fineart -Alfred Eisenstaedt,0.67213774,black-white -Leon Kossoff,0.67208946,fineart -Georges de La Tour,0.6421979,fineart -Chuck Close,0.6719756,digipa-high-impact -Herbert MacNair,0.6719506,scribbles -Edward Atkinson Hornel,0.6719265,fineart -Becky Cloonan,0.67192084,cartoon -Gian Lorenzo Bernini,0.58210254,fineart -Hein Gorny,0.4982776,digipa-med-impact -Joe Webb,0.6714884,fineart -Cornelis Pietersz Bega,0.64423996,fineart -Christian Krohg,0.6713641,fineart -Cornelia Parker,0.6712246,fineart -Anna Mary Robertson Moses,0.6709144,fineart -Quentin Tarantino,0.6708354,digipa-high-impact -Frederic Remington,0.67074275,fineart -Barent Fabritius,0.6707407,fineart -Oleg Oprisco,0.6707388,digipa-high-impact -Hendrick van Streeck,0.670666,fineart -Bakemono Zukushi,0.67051035,anime -Lucy Madox Brown,0.67032814,fineart -Paul Wonner,0.6700563,scribbles -Guido Borelli Da Caluso,0.66966087,digipa-high-impact -Emil Alzamora,0.5844039,nudity -Heinrich Brocksieper,0.64469147,fineart -Dan Smith,0.669563,digipa-high-impact -Lois van Baarle,0.6695091,scribbles -Arthur Garfield Dove,0.6694996,scribbles -Matthias Jung,0.66936135,digipa-high-impact -José Clemente Orozco,0.6693544,scribbles -Don Bluth,0.6693046,cartoon -Akseli Gallen-Kallela,0.66927314,fineart -Alex Howitt,0.52858865,digipa-high-impact -Giovanni Bernardino Asoleni,0.6635405,fineart -Frederick Goodall,0.6690712,fineart -Francesco Bartolozzi,0.63431,fineart -Edmund Leighton,0.6689639,fineart -Abraham Willaerts,0.5966594,fineart -François Louis Thomas Francia,0.6207474,fineart -Carel Fabritius,0.6688478,fineart -Flora Macdonald Reid,0.6687404,fineart -Bartholomeus Breenbergh,0.6163084,fineart -Bernardino Mei,0.6486895,fineart -Carel Weight,0.6684968,fineart -Aristide Maillol,0.66843045,scribbles -Chris Leib,0.60567486,fineart -Giovanni Battista Piazzetta,0.65012705,fineart -Daniel Maclise,0.6678073,fineart -Giovanni Bernardino Azzolini,0.65774256,fineart -Aaron Horkey,0.6676864,fineart -Otto Dix,0.667294,scribbles -Ferdinand Bol,0.6414797,fineart -Adriaen Coorte,0.6670663,fineart -William Gropper,0.6669881,scribbles -Gerard de Lairesse,0.6639489,fineart -Mab Graves,0.6668356,scribbles -Fernando Amorsolo,0.66683346,fineart -Pixar Concept Artists,0.6667752,cartoon -Alfred Augustus Glendening,0.64009607,fineart -Diego Velázquez,0.6666799,fineart -Jerry Pinkney,0.6665478,fineart -Antoine Wiertz,0.6143825,fineart -Alberto Burri,0.6618252,scribbles -Max Weber,0.6664029,fineart -Hans Baluschek,0.66636246,fineart -Annie Swynnerton,0.6663346,fineart -Albert Dubois-Pillet,0.57526016,fineart -Dora Maar,0.62862253,digipa-high-impact -Kay Sage,0.5614823,fineart -David A. Hardy,0.51376164,fineart -Alberto Biasi,0.42917693,digipa-low-impact -Fra Bartolomeo,0.6661105,fineart -Hendrick van Balen,0.65754294,fineart -Edwin Austin Abbey,0.66596496,fineart -George Frederic Watts,0.66595024,fineart -Alexei Kondratyevich Savrasov,0.6470352,fineart -Anna Ancher,0.66581213,fineart -Irma Stern,0.66580737,fineart -Frédéric Bazille,0.6657115,fineart -Awataguchi Takamitsu,0.6656272,anime -Edward Sorel,0.6655388,fineart -Edward Lear,0.6655078,fineart -Gabriel Metsu,0.6654555,fineart -Giovanni Battista Innocenzo Colombo,0.6653655,fineart -Scott Naismith,0.6650656,fineart -John Perceval,0.6650283,fineart -Girolamo Muziano,0.64234406,fineart -Cornelis de Man,0.66494393,fineart -Cornelis Bisschop,0.64119905,digipa-high-impact -Hans Leu the Elder,0.64770013,fineart -Michael Hutter,0.62479556,fineart -Cornelia MacIntyre Foley,0.6510235,fineart -Todd McFarlane,0.6647763,cartoon -John James Audubon,0.6279882,digipa-high-impact -William Henry Hunt,0.57340264,fineart -John Anster Fitzgerald,0.6644317,fineart -Tomer Hanuka,0.6643152,cartoon -Alex Prager,0.6641814,fineart -Heinrich Kley,0.6641148,fineart -Anne Redpath,0.66407835,scribbles -Marianne North,0.6640104,fineart -Daniel Merriam,0.6639365,fineart -Bill Carman,0.66390574,fineart -Méret Oppenheim,0.66387725,digipa-high-impact -Erich Heckel,0.66384083,fineart -Iryna Yermolova,0.663623,fineart -Antoine Ignace Melling,0.61502695,fineart -Akira Toriyama,0.6635002,anime -Gregory Crewdson,0.59810174,digipa-high-impact -Helene Schjerfbeck,0.66333634,fineart -Antonio Mancini,0.6631618,fineart -Zanele Muholi,0.58554715,n -Balthasar van der Ast,0.66294503,fineart -Toei Animations,0.6629127,anime -Arthur Quartley,0.6628106,fineart -Diego Rivera,0.6625808,fineart -Hendrik van Steenwijk II,0.6623777,fineart -James Tissot,0.6623415,fineart -Kehinde Wiley,0.66218376,n -Chiharu Shiota,0.6621249,digipa-high-impact -George Grosz,0.6620224,fineart -Peter De Seve,0.6616659,cartoon -Ryan Hewett,0.6615638,fineart -Hasegawa Tōhaku,0.66146004,anime -Apollinary Vasnetsov,0.6613177,fineart -Francis Cadell,0.66119456,fineart -Henri Harpignies,0.6611012,fineart -Henry Macbeth-Raeburn,0.6213787,fineart -Christoffel van den Berghe,0.6609149,fineart -Leiji Matsumoto,0.66089404,anime -Adriaen van der Werff,0.638286,fineart -Ramon Casas,0.6606529,fineart -Arthur Hacker,0.66062653,fineart -Edward Willis Redfield,0.66058433,fineart -Carl Gustav Carus,0.65355223,fineart -Francesca Woodman,0.60435605,digipa-high-impact -Hans Makart,0.5881955,fineart -Carne Griffiths,0.660091,weird -Will Barnet,0.65995145,scribbles -Fitz Henry Lane,0.659841,fineart -Masaaki Sasamoto,0.6597158,anime -Salvador Dali,0.6290813,scribbles -Walt Kelly,0.6596993,digipa-high-impact -Charlotte Nasmyth,0.56481636,fineart -Ferdinand Knab,0.6596528,fineart -Steve Lieber,0.6596117,scribbles -Zhang Kechun,0.6595939,fareast -Olivier Valsecchi,0.5324838,digipa-high-impact -Joel Meyerowitz,0.65937585,digipa-high-impact -Arthur Streeton,0.6592294,fineart -Henriett Seth F.,0.6592273,fineart -Genndy Tartakovsky,0.6591695,scribbles -Otto Marseus van Schrieck,0.65890455,fineart -Hanna-Barbera,0.6588123,cartoon -Mary Anning,0.6588001,fineart -Pamela Colman Smith,0.6587648,fineart -Anton Mauve,0.6586873,fineart -Hendrick Avercamp,0.65866685,fineart -Max Pechstein,0.65860206,scribbles -Franciszek Żmurko,0.56855476,fineart -Felice Casorati,0.6584761,fineart -Louis Janmot,0.65298057,fineart -Thomas Cole,0.5408042,fineart -Peter Mohrbacher,0.58273685,fineart -Arnold Franz Brasz,0.65834284,nudity -Christian Rohlfs,0.6582814,fineart -Basil Gogos,0.658105,fineart -Fitz Hugh Lane,0.657923,fineart -Liubov Sergeevna Popova,0.62325525,fineart -Elizabeth MacNicol,0.65773135,fineart -Zinaida Serebriakova,0.6577016,fineart -Ernest Lawson,0.6575238,fineart -Bruno Catalano,0.6574354,fineart -Albert Namatjira,0.6573372,fineart -Fritz von Uhde,0.6572697,fineart -Edwin Henry Landseer,0.62363374,fineart -Naoto Hattori,0.621745,fareast -Reylia Slaby,0.65709853,fineart -Arthur Burdett Frost,0.6147318,fineart -Frank Miller,0.65707314,digipa-high-impact -Algernon Talmage,0.65702903,fineart -Itō Jakuchū,0.6570199,digipa-high-impact -Billie Waters,0.65684533,digipa-high-impact -Ingrid Baars,0.58558,digipa-high-impact -Pieter Jansz Saenredam,0.6566058,fineart -Egbert van Heemskerck,0.6125889,fineart -John French Sloan,0.6362145,fineart -Craola,0.65639997,scribbles -Benjamin Marra,0.61809736,nudity -Anthony Thieme,0.65609205,fineart -Satoshi Kon,0.65606606,anime -Masamune Shirow,0.65592873,anime -Alfred Stevens,0.6557321,fineart -Hariton Pushwagner,0.6556745,anime -Carlo Carrà,0.6556279,fineart -Stuart Davis,0.6050534,digipa-high-impact -David Shrigley,0.6553904,digipa-high-impact -Albrecht Anker,0.65531695,fineart -Anton Semenov,0.6552501,digipa-high-impact -Fabio Hurtado,0.5955889,fineart -Donald Judd,0.6552257,fineart -Francisco de Burgos Mantilla,0.65516514,fineart -Barthel Bruyn the Younger,0.6551433,fineart -Abram Arkhipov,0.6550962,fineart -Paulus Potter,0.65498203,fineart -Edward Lamson Henry,0.6549521,fineart -Audrey Kawasaki,0.654843,fineart -George Catlin,0.6547183,fineart -Adélaïde Labille-Guiard,0.6066263,fineart -Sandy Skoglund,0.6546999,digipa-high-impact -Hans Baldung,0.654431,fineart -Ethan Van Sciver,0.65442884,cartoon -Frans Hals,0.6542338,fineart -Caspar David Friedrich,0.6542175,fineart -Charles Conder,0.65420866,fineart -Betty Churcher,0.65387225,fineart -Claes Corneliszoon Moeyaert,0.65386075,fineart -David Bomberg,0.6537477,fineart -Abraham Bosschaert,0.6535562,fineart -Giuseppe de Nittis,0.65354455,fineart -John La Farge,0.65342575,fineart -Frits Thaulow,0.65341854,fineart -John Duncan,0.6532379,fineart -Floris van Dyck,0.64900756,fineart -Anton Pieck,0.65310377,fineart -Roger Dean,0.6529647,nudity -Maximilian Pirner,0.65280807,fineart -Dorothy Johnstone,0.65267503,fineart -Govert Dircksz Camphuysen,0.65258145,fineart -Ryohei Hase,0.6168618,fineart -Hans von Aachen,0.62437224,fineart -Gustaf Munch-Petersen,0.6522485,fineart -Earnst Haeckel,0.6344333,fineart -Giovanni Battista Bracelli,0.62635326,fineart -Hendrick Goudt,0.6521433,fineart -Aneurin Jones,0.65191466,fineart -Bryan Hitch,0.6518333,cartoon -Coby Whitmore,0.6515695,fineart -Barthélemy d'Eyck,0.65156406,fineart -Quint Buchholz,0.65151155,fineart -Adriaen Hanneman,0.6514815,fineart -Tom Roberts,0.5855832,fineart -Fernand Khnopff,0.6512954,nudity -Charles Vess,0.6512271,cartoon -Carlo Galli Bibiena,0.6511681,nudity -Alexander Milne Calder,0.6081027,fineart -Josan Gonzalez,0.6193469,cartoon -Barthel Bruyn the Elder,0.6509954,fineart -Jon Whitcomb,0.6046063,fineart -Arcimboldo,0.6509897,fineart -Hendrik van Steenwijk I,0.65086293,fineart -Albert Joseph Pénot,0.65085316,fineart -Edward Wadsworth,0.6308917,scribbles -Andrew Wyeth,0.6507103,fineart -Correggio,0.650689,fineart -Frances Currey,0.65068,fineart -Henryk Siemiradzki,0.56721973,fineart -Worthington Whittredge,0.6504713,fineart -Federico Zandomeneghi,0.65033823,fineart -Isaac Levitan,0.6503356,fineart -Russ Mills,0.65012795,fineart -Edith Lawrence,0.65010095,fineart -Gil Elvgren,0.5614284,digipa-high-impact -Chris Foss,0.56495357,fineart -Francesco Zuccarelli,0.612805,fineart -Hendrick Bloemaert,0.64962655,fineart -Egon von Vietinghoff,0.57180583,fineart -Pixar,0.6495793,cartoon -Daniel Clowes,0.6495775,fineart -Friedrich Ritter von Friedländer-Malheim,0.6493772,fineart -Rebecca Sugar,0.6492679,scribbles -Chen Daofu,0.6492026,fineart -Dustin Nguyen,0.64909416,cartoon -Raymond Duchamp-Villon,0.6489605,nudity -Daniel Garber,0.6489332,fineart -Antonio Canova,0.58764786,fineart -Algernon Blackwood,0.59256804,fineart -Betye Saar,0.64877665,fineart -William S. Burroughs,0.5505619,fineart -Rodney Matthews,0.64844495,fineart -Michelangelo Buonarroti,0.6484401,fineart -Posuka Demizu,0.64843124,anime -Joao Ruas,0.6484134,fineart -Andy Fairhurst,0.6480388,special -"Andries Stock, Dutch Baroque painter",0.6479797,fineart -Antonio de la Gandara,0.6479292,fineart -Bruce Timm,0.6477877,scribbles -Harvey Kurtzman,0.64772683,cartoon -Eiichiro Oda,0.64772165,anime -Edwin Landseer,0.6166703,fineart -Carl Heinrich Bloch,0.64755356,fineart -Adriaen Isenbrant,0.6475428,fineart -Santiago Caruso,0.6473954,fineart -Alfred Guillou,0.6472603,fineart -Clara Peeters,0.64725095,fineart -Kim Jung Gi,0.6472225,cartoon -Milo Manara,0.6471776,cartoon -Phil Noto,0.6470769,anime -Kaws,0.6470336,cartoon -Desmond Morris,0.5951916,fineart -Gediminas Pranckevicius,0.6467787,fineart -Jack Kirby,0.6467424,cartoon -Claes Jansz. Visscher,0.6466888,fineart -Augustin Meinrad Bächtiger,0.6465789,fineart -John Lavery,0.64643383,fineart -Anne Bachelier,0.6464065,fineart -Giuseppe Bernardino Bison,0.64633006,fineart -E. T. A. Hoffmann,0.5887251,fineart -Ambrosius Benson,0.6457839,fineart -Cornelis Verbeeck,0.645782,fineart -H. R. Giger,0.6456823,weird -Adolph Menzel,0.6455246,fineart -Aliza Razell,0.5863178,digipa-high-impact -Gerard Seghers,0.6205679,fineart -David Aja,0.62812066,scribbles -Gustave Courbet,0.64476407,fineart -Alexandre Cabanel,0.63849115,fineart -Albert Marquet,0.64471006,fineart -Harold Harvey,0.64464307,fineart -William Wegman,0.6446265,scribbles -Harold Gilman,0.6445966,fineart -Jeremy Geddes,0.57839495,digipa-high-impact -Abraham van Beijeren,0.6356113,fineart -Eugène Isabey,0.6160607,fineart -Jorge Jacinto,0.58618563,fineart -Frederic Leighton,0.64383554,fineart -Dave McKean,0.6438012,cartoon -Hiromu Arakawa,0.64371413,anime -Aaron Douglas,0.6437089,fineart -Adolf Dietrich,0.590169,fineart -Frederik de Moucheron,0.6435952,fineart -Siya Oum,0.6435919,cartoon -Alberto Morrocco,0.64352196,fineart -Robert Vonnoh,0.6433115,fineart -Tom Bagshaw,0.5322264,fineart -Guerrilla Girls,0.64309967,digipa-high-impact -Johann Wolfgang von Goethe,0.6429888,fineart -Charles Le Roux,0.6426594,fineart -Auguste Toulmouche,0.64261353,fineart -Cindy Sherman,0.58666563,digipa-high-impact -Federico Zuccari,0.6425021,fineart -Mike Mignola,0.642346,cartoon -Cecily Brown,0.6421981,fineart -Brian K. Vaughan,0.64147836,cartoon -RETNA (Marquis Lewis),0.47963,n -Klaus Janson,0.64129144,cartoon -Alessandro Galli Bibiena,0.6412889,fineart -Jeremy Lipking,0.64123213,fineart -Stephen Shore,0.64108944,digipa-high-impact -Heinz Edelmann,0.51325977,digipa-med-impact -Joaquín Sorolla,0.6409732,fineart -Bella Kotak,0.6409608,digipa-high-impact -Cornelis Engebrechtsz,0.64091057,fineart -Bruce Munro,0.64084166,digipa-high-impact -Marjane Satrapi,0.64076495,fineart -Jeremy Mann,0.557744,digipa-high-impact -Heinrich Maria Davringhausen,0.6403986,fineart -Kengo Kuma,0.6402023,digipa-high-impact -Alfred Manessier,0.640153,fineart -Antonio Galli Bibiena,0.6399247,digipa-high-impact -Eduard von Grützner,0.6397164,fineart -Bunny Yeager,0.5455078,digipa-high-impact -Adolphe Willette,0.6396935,fineart -Wangechi Mutu,0.6394607,n -Peter Milligan,0.6391612,digipa-high-impact -Dalí,0.45400402,digipa-low-impact -Élisabeth Vigée Le Brun,0.6388982,fineart -Beth Conklin,0.6388204,digipa-high-impact -Charles Alphonse du Fresnoy,0.63881266,fineart -Thomas Benjamin Kennington,0.56668127,fineart -Jim Woodring,0.5625168,fineart -Francisco Oller,0.63846034,fineart -Csaba Markus,0.6384506,fineart -Botero,0.63843524,scribbles -Bill Henson,0.5394536,digipa-high-impact -Anna Bocek,0.6382304,scribbles -Hugo van der Goes,0.63822484,fineart -Robert William Hume,0.5433574,fineart -Chip Zdarsky,0.6381826,cartoon -Daniel Seghers,0.53494316,fineart -Richard Doyle,0.6377541,fineart -Hendrick Terbrugghen,0.63773805,fineart -Joe Madureira,0.6377177,special -Floris van Schooten,0.6376191,fineart -Jeff Simpson,0.3959046,fineart -Albert Joseph Moore,0.6374316,fineart -Arthur Merric Boyd,0.6373228,fineart -Amadeo de Souza Cardoso,0.5927926,fineart -Os Gemeos,0.6368859,digipa-high-impact -Giovanni Boldini,0.6368698,fineart -Albert Goodwin,0.6368695,fineart -Hans Eduard von Berlepsch-Valendas,0.61562145,fineart -Edmond Xavier Kapp,0.5758474,fineart -François Quesnel,0.6365935,fineart -Nathan Coley,0.6365817,digipa-high-impact -Jasmine Becket-Griffith,0.6365083,digipa-high-impact -Raphaelle Peale,0.6364422,fineart -Candido Portinari,0.63634276,fineart -Edward Dugmore,0.63179636,fineart -Anders Zorn,0.6361722,fineart -Ed Emshwiller,0.63615763,fineart -Francis Coates Jones,0.6361159,fineart -Ernst Haas,0.6361123,digipa-high-impact -Dirck van Baburen,0.6213001,fineart -René Lalique,0.63594735,fineart -Sydney Prior Hall,0.6359345,fineart -Brad Kunkle,0.5659712,fineart -Corneille,0.6356381,fineart -Henry Lamb,0.63560975,fineart -Dirck Hals,0.63559663,fineart -Alex Grey,0.62908936,nudity -Michael Heizer,0.63555753,fineart -Yiannis Moralis,0.61731136,fineart -Emily Murray Paterson,0.4392335,fineart -Georg Friedrich Kersting,0.6256248,fineart -Frances Hodgkins,0.6352128,fineart -Charles Cundall,0.6349486,fineart -Henry Wallis,0.63478243,fineart -Goro Fujita,0.6346491,cartoon -Jean-Léon Gérôme,0.5954844,fineart -August von Pettenkofen,0.60910493,fineart -Abbott Handerson Thayer,0.63428533,fineart -Martin John Heade,0.5926603,fineart -Ellen Jewett,0.63420236,digipa-high-impact -Hidari Jingorō,0.63388014,fareast -Taiyō Matsumoto,0.63372946,special -Emanuel Leutze,0.6007246,fineart -Adam Martinakis,0.48973057,digipa-med-impact -Will Eisner,0.63349223,cartoon -Alexander Stirling Calder,0.6331682,fineart -Saturno Butto,0.6331184,nudity -Cecilia Beaux,0.6330725,fineart -Amandine Van Ray,0.6174208,digipa-high-impact -Bob Eggleton,0.63277495,digipa-high-impact -Sherree Valentine Daines,0.63274443,fineart -Frederick Lord Leighton,0.6299176,fineart -Daniel Ridgway Knight,0.63251615,fineart -Gaetano Previati,0.61743724,fineart -John Berkey,0.63226986,fineart -Richard Misrach,0.63201725,digipa-high-impact -Aaron Jasinski,0.57948315,fineart -"Edward Otho Cresap Ord, II",0.6317712,fineart -Evelyn De Morgan,0.6317376,fineart -Noelle Stevenson,0.63159716,digipa-high-impact -Edward Robert Hughes,0.6315573,fineart -Allan Ramsay,0.63150716,fineart -Balthus,0.6314323,scribbles -Hendrick Cornelisz Vroom,0.63143134,digipa-high-impact -Ilya Repin,0.6313043,fineart -George Lambourn,0.6312267,fineart -Arthur Hughes,0.6310194,fineart -Antonio J. Manzanedo,0.53841716,fineart -John Singleton Copley,0.6264835,fineart -Dennis Miller Bunker,0.63078755,fineart -Ernie Barnes,0.6307126,cartoon -Alison Kinnaird,0.6306353,digipa-high-impact -Alex Toth,0.6305541,digipa-high-impact -Henry Raeburn,0.6155551,fineart -Alice Bailly,0.6305177,fineart -Brian Kesinger,0.63037646,scribbles -Antoine Blanchard,0.63036835,fineart -Ron Walotsky,0.63035095,fineart -Kent Monkman,0.63027304,fineart -Naomi Okubo,0.5782754,fareast -Hercules Seghers,0.62957174,fineart -August Querfurt,0.6295643,fineart -Samuel Melton Fisher,0.6283333,fineart -David Burdeny,0.62950236,digipa-high-impact -George Bain,0.58519644,fineart -Peter Holme III,0.62938106,fineart -Grayson Perry,0.62928164,digipa-high-impact -Chris Claremont,0.6292076,digipa-high-impact -Dod Procter,0.6291759,fineart -Huang Tingjian,0.6290358,fareast -Dorothea Warren O'Hara,0.6290113,fineart -Ivan Albright,0.6289551,fineart -Hubert von Herkomer,0.6288955,fineart -Barbara Nessim,0.60589516,digipa-high-impact -Henry Scott Tuke,0.6286309,fineart -Ditlev Blunck,0.6282925,fineart -Sven Nordqvist,0.62828535,fineart -Lee Madgwick,0.6281731,fineart -Hubert van Eyck,0.6281529,fineart -Edmond Bille,0.62339354,fineart -Ejnar Nielsen,0.6280824,fineart -Arturo Souto,0.6280583,fineart -Jean Giraud,0.6279888,fineart -Storm Thorgerson,0.6277394,digipa-high-impact -Ed Benedict,0.62764007,digipa-high-impact -Christoffer Wilhelm Eckersberg,0.6014842,fineart -Clarence Holbrook Carter,0.5514105,fineart -Dorothy Lockwood,0.6273235,fineart -John Singer Sargent,0.6272487,fineart -Brigid Derham,0.6270125,digipa-high-impact -Henricus Hondius II,0.6268505,fineart -Gertrude Harvey,0.5903887,fineart -Grant Wood,0.6266253,fineart -Fyodor Vasilyev,0.5234919,digipa-med-impact -Cagnaccio di San Pietro,0.6261671,fineart -Doris Boulton-Maude,0.62593174,fineart -Adolf Hirémy-Hirschl,0.5946784,fineart -Harold von Schmidt,0.6256755,fineart -Martine Johanna,0.6256161,digipa-high-impact -Gerald Kelly,0.5579602,digipa-high-impact -Ub Iwerks,0.625396,cartoon -Dirck van der Lisse,0.6253871,fineart -Edouard Riou,0.6250113,fineart -Ilya Yefimovich Repin,0.62491584,fineart -Martin Johnson Heade,0.59421235,fineart -Afarin Sajedi,0.62475824,scribbles -Alfred Thompson Bricher,0.6247515,fineart -Edwin G. Lucas,0.5553578,fineart -Georges Emile Lebacq,0.56175387,fineart -Francis Davis Millet,0.5988504,fineart -Bill Sienkiewicz,0.6125557,digipa-high-impact -Giocondo Albertolli,0.62441677,fineart -Victor Nizovtsev,0.6242258,fineart -Squeak Carnwath,0.62416434,digipa-high-impact -Bill Viola,0.62409425,digipa-high-impact -Annie Abernethie Pirie Quibell,0.6240767,fineart -Jason Edmiston,0.62405366,fineart -Al Capp,0.6239494,fineart -Kobayashi Kiyochika,0.6239368,anime -Albert Anker,0.62389827,fineart -Iain Faulkner,0.62376785,fineart -Todd Schorr,0.6237408,fineart -Charles Ginner,0.62370133,fineart -Emile Auguste Carolus-Duran,0.62353987,fineart -John Philip Falter,0.623418,cartoon -Chizuko Yoshida,0.6233001,fareast -Anna Dittmann,0.62327325,cartoon -Henry Snell Gamley,0.62319934,fineart -Edmund Charles Tarbell,0.6230626,fineart -Rob Gonsalves,0.62298363,fineart -Gladys Dawson,0.6228511,fineart -Tomma Abts,0.61153626,fineart -Kate Beaton,0.53993124,digipa-high-impact -Gustave Buchet,0.62243867,fineart -Gareth Pugh,0.6223551,digipa-high-impact -Caspar van Wittel,0.57871693,fineart -Anton Otto Fischer,0.6222941,fineart -Albert Guillaume,0.56529653,fineart -Felix Octavius Carr Darley,0.62223387,fineart -Bernard van Orley,0.62221646,fineart -Edward John Poynter,0.60147405,fineart -Walter Percy Day,0.62207425,fineart -Franciszek Starowieyski,0.5709621,fineart -Auguste Baud-Bovy,0.6219854,fineart -Chris LaBrooy,0.45497298,digipa-low-impact -Abraham de Vries,0.5859101,fineart -Antoni Gaudi,0.62162614,fineart -Joe Jusko,0.62156093,digipa-high-impact -Lynda Barry,0.62154603,digipa-high-impact -Michal Karcz,0.62154436,digipa-high-impact -Raymond Briggs,0.62150294,fineart -Herbert James Gunn,0.6210927,fineart -Dwight William Tryon,0.620984,fineart -Paul Henry,0.5752968,fineart -Helio Oiticica,0.6203739,digipa-high-impact -Sebastian Errazuriz,0.62036186,digipa-high-impact -Lucian Freud,0.6203146,nudity -Frank Auerbach,0.6201102,weird -Andre-Charles Boulle,0.6200789,fineart -Franz Fedier,0.5669752,fineart -Austin Briggs,0.57675314,fineart -Hugo Sánchez Bonilla,0.61978436,digipa-high-impact -Caroline Chariot-Dayez,0.6195682,digipa-high-impact -Bill Ward,0.61953044,digipa-high-impact -Charles Bird King,0.6194487,fineart -Adrian Ghenie,0.6193521,digipa-high-impact -Agnes Cecile,0.6192814,digipa-high-impact -Augustus John,0.6191995,fineart -Jeffrey T. Larson,0.61913544,fineart -Alexis Simon Belle,0.3190395,digipa-low-impact -Jean-Baptiste Monge,0.5758537,fineart -Adolf Bierbrauer,0.56129396,fineart -Ayako Rokkaku,0.61891204,fareast -Lisa Keene,0.54570895,digipa-high-impact -Edmond Aman-Jean,0.57168096,fineart -Marc Davis,0.61837333,cartoon -Cerith Wyn Evans,0.61829346,digipa-high-impact -George Wyllie,0.61829203,fineart -George Luks,0.6182724,fineart -William-Adolphe Bouguereau,0.618265,c -Grigoriy Myasoyedov,0.61801606,fineart -Hashimoto Gahō,0.61795104,fineart -Charles Ragland Bunnell,0.61772746,fineart -Ambrose McCarthy Patterson,0.61764514,fineart -Bill Brauer,0.5824066,fineart -Mikko Lagerstedt,0.591015,digipa-high-impact -Koson Ohara,0.53635323,fineart -Evaristo Baschenis,0.5857368,fineart -Martin Ansin,0.5294119,fineart -Cory Loftis,0.6168619,cartoon -Joseph Stella,0.6166778,fineart -André Pijet,0.5768274,fineart -Jeff Wall,0.6162895,digipa-high-impact -Eleanor Layfield Davis,0.6158844,fineart -Saul Tepper,0.61579347,fineart -Alex Hirsch,0.6157384,cartoon -Alexandre Falguière,0.55011404,fineart -Malcolm Liepke,0.6155646,fineart -Georg Friedrich Schmidt,0.60364646,fineart -Hendrik Kerstens,0.55099905,digipa-high-impact -Félix Bódog Widder,0.6153954,fineart -Marie Guillemine Benoist,0.61532974,fineart -Kelly Mckernan,0.60047054,digipa-high-impact -Ignacio Zuloaga,0.6151608,fineart -Hubert van Ravesteyn,0.61489964,fineart -Angus McKie,0.61487424,digipa-high-impact -Colin Campbell Cooper,0.6147882,fineart -Pieter Aertsen,0.61454165,fineart -Jan Brett,0.6144608,fineart -Kazuo Koike,0.61438507,fineart -Edith Grace Wheatley,0.61428297,fineart -Ogawa Kazumasa,0.61427975,fareast -Giovanni Battista Cipriani,0.6022825,fineart -André Bauchant,0.57124996,fineart -George Abe,0.6140447,digipa-high-impact -Georges Lemmen,0.6139967,scribbles -Frank Leonard Brooks,0.6139327,fineart -Gai Qi,0.613744,anime -Frank Gehry,0.6136776,digipa-high-impact -Anton Domenico Gabbiani,0.55471313,fineart -Cassandra Austen,0.6135781,fineart -Paul Gustav Fischer,0.613273,fineart -Emiliano Di Cavalcanti,0.6131207,fineart -Meryl McMaster,0.6129995,digipa-high-impact -Domenico di Pace Beccafumi,0.6129922,fineart -Ludwig Mies van der Rohe,0.6126692,fineart -Étienne-Louis Boullée,0.6126158,fineart -Dali,0.5928694,nudity -Shinji Aramaki,0.61246127,anime -Giovanni Fattori,0.59544694,fineart -Bapu,0.6122084,c -Raphael Lacoste,0.5539114,digipa-high-impact -Scarlett Hooft Graafland,0.6119631,digipa-high-impact -Rene Laloux,0.61190474,fineart -Julius Horsthuis,0.59037095,fineart -Gerald van Honthorst,0.6115939,fineart -Dino Valls,0.611533,fineart -Tony DiTerlizzi,0.6114657,cartoon -Michael Cheval,0.61138546,anime -Charles Schulz,0.6113759,digipa-high-impact -Alvar Aalto,0.61122143,digipa-high-impact -Gu Kaizhi,0.6110798,fareast -Eugene von Guerard,0.6109776,fineart -John Cassaday,0.610949,fineart -Elizabeth Forbes,0.61092335,fineart -Edmund Greacen,0.6109115,fineart -Eugène Burnand,0.6107876,fineart -Boris Grigoriev,0.6107853,scribbles -Norman Rockwell,0.6107638,fineart -Barthélemy Menn,0.61064315,fineart -George Biddle,0.61058354,fineart -Edgar Ainsworth,0.5525424,digipa-high-impact -Alfred Leyman,0.5887217,fineart -Tex Avery,0.6104007,cartoon -Beatrice Ethel Lithiby,0.61030364,fineart -Grace Pailthorpe,0.61026484,digipa-high-impact -Brian Oldham,0.396231,digipa-low-impact -Android Jones,0.61023116,fareast -François Girardon,0.5830649,fineart -Ib Eisner,0.61016303,digipa-high-impact -Armand Point,0.610156,fineart -Henri Alphonse Barnoin,0.59465057,fineart -Jean Marc Nattier,0.60987425,fineart -Francisco de Holanda,0.6091294,fineart -Marco Mazzoni,0.60970783,fineart -Esaias Boursse,0.6093308,fineart -Alexander Deyneka,0.55000365,fineart -John Totleben,0.60883725,fineart -Al Feldstein,0.6087723,fineart -Adam Hughes,0.60854626,anime -Ernest Zobole,0.6085073,fineart -Alex Gross,0.60837066,digipa-high-impact -George Jamesone,0.6079673,fineart -Frank Lloyd Wright,0.60793245,scribbles -Brooke DiDonato,0.47680336,digipa-med-impact -Hans Gude,0.60780364,fineart -Ethel Schwabacher,0.60748273,fineart -Gladys Kathleen Bell,0.60747695,fineart -Adolf Fényes,0.54192233,fineart -Carel Willink,0.58120143,fineart -George Henry,0.6070727,digipa-high-impact -Ronald Balfour,0.60697085,fineart -Elsie Dalton Hewland,0.6067718,digipa-high-impact -Alex Maleev,0.6067118,fineart -Anish Kapoor,0.6067015,digipa-high-impact -Aleksandr Ivanovich Laktionov,0.606544,fineart -Kim Keever,0.6037775,digipa-high-impact -Aleksi Briclot,0.46056762,fineart -Raymond Leech,0.6062721,fineart -Richard Eurich,0.6062664,fineart -Phil Jimenez,0.60625625,cartoon -Gao Cen,0.60618126,nudity -Mike Deodato,0.6061201,cartoon -Charles Haslewood Shannon,0.6060581,fineart -Alexandre Jacovleff,0.3991747,digipa-low-impact -André Beauneveu,0.584062,fineart -Hiroshi Honda,0.60507596,digipa-high-impact -Charles Joshua Chaplin,0.60498774,fineart -Domenico Zampieri,0.6049726,fineart -Gusukuma Seihō,0.60479784,fareast -Nikolina Petolas,0.46318632,digipa-low-impact -Casey Weldon,0.6047672,cartoon -Elmyr de Hory,0.6046374,fineart -Nan Goldin,0.6046119,digipa-high-impact -Charles McAuley,0.6045995,fineart -Archibald Skirving,0.6044234,fineart -Elizabeth York Brunton,0.6043737,fineart -Dugald Sutherland MacColl,0.6042907,fineart -Titian,0.60426414,fineart -Ignacy Witkiewicz,0.6042259,fineart -Allie Brosh,0.6042061,digipa-high-impact -H.P. Lovecraft,0.6039597,digipa-high-impact -Andrée Ruellan,0.60395086,fineart -Ralph McQuarrie,0.60380936,fineart -Mead Schaeffer,0.6036558,fineart -Henri-Julien Dumont,0.571257,fineart -Kieron Gillen,0.6035093,fineart -Maginel Wright Enright Barney,0.6034306,nudity -Vincent Di Fate,0.6034131,fineart -Briton Rivière,0.6032918,fineart -Hajime Sorayama,0.60325956,nudity -Béla Czóbel,0.6031023,fineart -Edmund Blampied,0.603072,fineart -E. Simms Campbell,0.6030443,fineart -Hisui Sugiura,0.603034,fareast -Alan Davis,0.6029676,fineart -Glen Keane,0.60287905,cartoon -Frank Holl,0.6027312,fineart -Abbott Fuller Graves,0.6025608,fineart -Albert Servaes,0.60250103,black-white -Hovsep Pushman,0.5937487,fineart -Brian M. Viveros,0.60233414,fineart -Charles Fremont Conner,0.6023278,fineart -Francesco Furini,0.6022654,digipa-high-impact -Camille-Pierre Pambu Bodo,0.60191673,fineart -Yasushi Nirasawa,0.6016714,nudity -Charles Uzzell-Edwards,0.6014683,fineart -Abram Efimovich Arkhipov,0.60128385,fineart -Hedda Sterne,0.6011857,digipa-high-impact -Ben Aronson,0.6011548,fineart -Frank Frazetta,0.551121,nudity -Elizabeth Durack,0.6010842,fineart -Ian Miller,0.42153555,fareast -Charlie Bowater,0.4410439,special -Michael Carson,0.60039437,fineart -Walter Langley,0.6002273,fineart -Cornelis Anthonisz,0.6001956,fineart -Dorothy Elizabeth Bradford,0.6001929,fineart -J.C. Leyendecker,0.5791972,fineart -Willem van Haecht,0.59990716,fineart -Anna and Elena Balbusso,0.59955937,digipa-low-impact -Harrison Fisher,0.59952044,fineart -Bill Medcalf,0.59950054,fineart -Edward Arthur Walton,0.59945667,fineart -Alois Arnegger,0.5991994,fineart -Ray Caesar,0.59902894,digipa-high-impact -Karen Wallis,0.5990094,fineart -Emmanuel Shiu,0.51082766,digipa-med-impact -Thomas Struth,0.5988324,digipa-high-impact -Barbara Longhi,0.5985706,fineart -Richard Deacon,0.59851056,fineart -Constantin Hansen,0.5984213,fineart -Harold Shapinsky,0.5984175,fineart -George Dionysus Ehret,0.5983857,fineart -Doug Wildey,0.5983639,digipa-high-impact -Fernand Toussaint,0.5982694,fineart -Horatio Nelson Poole,0.5982614,fineart -Caesar van Everdingen,0.5981566,fineart -Eva Gonzalès,0.5981396,fineart -Franz Vohwinkel,0.5448179,fineart -Margaret Mee,0.5979592,fineart -Francis Focer Brown,0.59779185,fineart -Henry Moore,0.59767926,nudity -Scott Listfield,0.58795893,fineart -Nikolai Ge,0.5973643,fineart -Jacek Yerka,0.58198756,fineart -Margaret Brundage,0.5969077,fineart -JC Leyendecker,0.5620243,fineart -Ben Templesmith,0.5498991,digipa-high-impact -Armin Hansen,0.59669334,anime -Jean-Louis Prevost,0.5966897,fineart -Daphne Allen,0.59666026,fineart -Franz Karl Basler-Kopp,0.59663445,fineart -"Henry Ives Cobb, Jr.",0.596385,fineart -Michael Sowa,0.546285,fineart -Anna Füssli,0.59600973,fineart -György Rózsahegyi,0.59580946,fineart -Luis Royo,0.59566617,fineart -Émile Gallé,0.5955559,fineart -Antonio Mora,0.5334297,digipa-high-impact -Edward P. Beard Jr.,0.59543866,fineart -Jessica Rossier,0.54958373,special -André Thomkins,0.5343785,digipa-high-impact -David Macbeth Sutherland,0.5949968,fineart -Charles Liu,0.5949787,digipa-high-impact -Edi Rama,0.5949226,digipa-high-impact -Jacques Le Moyne,0.5948843,fineart -Egbert van der Poel,0.59488285,fineart -Georg Jensen,0.594782,digipa-high-impact -Anne Sudworth,0.5947539,fineart -Jan Pietersz Saenredam,0.59472525,fineart -Henryk Stażewski,0.5945748,fineart -André François,0.58402044,fineart -Alexander Runciman,0.5944449,digipa-high-impact -Thomas Kinkade,0.594391,fineart -Robert Williams,0.5567989,digipa-high-impact -George Gardner Symons,0.57431924,fineart -D. Alexander Gregory,0.5334464,fineart -Gerald Brom,0.52473724,fineart -Robert Hagan,0.59406,fineart -Ernest Crichlow,0.5940588,fineart -Viviane Sassen,0.5939927,digipa-high-impact -Enrique Simonet,0.5937546,fineart -Esther Blaikie MacKinnon,0.593747,digipa-high-impact -Jeff Kinney,0.59372896,scribbles -Igor Morski,0.5936732,digipa-high-impact -John Currin,0.5936216,fineart -Bob Ringwood,0.5935273,digipa-high-impact -Jordan Grimmer,0.44948143,digipa-low-impact -François Barraud,0.5933471,fineart -Helen Binyon,0.59331006,digipa-high-impact -Brenda Chamberlain,0.5932333,fineart -Candido Bido,0.59310603,fineart -Abraham Storck,0.5929502,fineart -Raphael,0.59278333,fineart -Larry Sultan,0.59273386,digipa-high-impact -Agostino Tassi,0.59265685,fineart -Alexander V. Kuprin,0.5925917,fineart -Frans Koppelaar,0.5658725,fineart -Richard Corben,0.59251785,fineart -David Gilmour Blythe,0.5924247,digipa-high-impact -František Kaván,0.5924211,fineart -Rob Liefeld,0.5921167,fineart -Ernő Rubik,0.5920297,fineart -Byeon Sang-byeok,0.59200096,fareast -Johfra Bosschart,0.5919376,fineart -Emil Lindenfeld,0.5761086,fineart -Howard Mehring,0.5917471,fineart -Gwenda Morgan,0.5915571,digipa-high-impact -Henry Asencio,0.5915404,fineart -"George Barret, Sr.",0.5914306,fineart -Andrew Ferez,0.5911011,fineart -Ed Brubaker,0.5910869,digipa-high-impact -George Reid,0.59095883,digipa-high-impact -Derek Gores,0.51769906,digipa-med-impact -Charles Rollier,0.5539186,fineart -Terry Oakes,0.590443,fineart -Thomas Blackshear,0.5078616,fineart -Albert Benois,0.5902705,nudity -Krenz Cushart,0.59026587,special -Jeff Koons,0.5902637,digipa-high-impact -Akihiko Yoshida,0.5901294,special -Anja Percival,0.45039332,digipa-low-impact -Eduard von Steinle,0.59008586,fineart -Alex Russell Flint,0.5900352,digipa-high-impact -Edward Okuń,0.5897297,fineart -Emma Lampert Cooper,0.5894849,fineart -Stuart Haygarth,0.58132994,digipa-high-impact -George French Angas,0.5434376,fineart -Edmund F. Ward,0.5892848,fineart -Eleanor Vere Boyle,0.58925456,digipa-high-impact -Evelyn Cheston,0.58924586,fineart -Edwin Dickinson,0.58921975,digipa-high-impact -Christophe Vacher,0.47325426,fineart -Anne Dewailly,0.58905107,fineart -Gertrude Greene,0.5862596,digipa-high-impact -Boris Groh,0.5888809,digipa-high-impact -Douglas Smith,0.588804,digipa-high-impact -Ian Hamilton Finlay,0.5887713,fineart -Derek Jarman,0.5887292,digipa-high-impact -Archibald Thorburn,0.5882001,fineart -Gillis d'Hondecoeter,0.58813053,fineart -I Ketut Soki,0.58801544,digipa-high-impact -Alex Schomburg,0.46614102,digipa-low-impact -Bastien L. Deharme,0.583349,special -František Jakub Prokyš,0.58782333,fineart -Jesper Ejsing,0.58782053,fineart -Odd Nerdrum,0.53551745,digipa-high-impact -Tom Lovell,0.5877577,fineart -Ayami Kojima,0.5877416,fineart -Peter Sculthorpe,0.5875696,fineart -Bernard D’Andrea,0.5874042,fineart -Denis Eden,0.58739066,digipa-high-impact -Alfons Walde,0.58728385,fineart -Jovana Rikalo,0.47006977,digipa-low-impact -Franklin Booth,0.5870834,fineart -Mat Collishaw,0.5870676,digipa-high-impact -Joseph Lorusso,0.586858,fineart -Helen Stevenson,0.454647,digipa-low-impact -Delaunay,0.58657396,fineart -H.R. Millar,0.58655745,fineart -E. Charlton Fortune,0.586376,fineart -Alson Skinner Clark,0.58631575,fineart -Stan And Jan Berenstain,0.5862361,digipa-high-impact -Howard Lyon,0.5862271,fineart -John Blanche,0.586182,fineart -Bernardo Cavallino,0.5858575,fineart -Tomasz Alen Kopera,0.5216588,fineart -Peter Gric,0.58583695,fineart -Guo Pei,0.5857794,fareast -James Turrell,0.5853901,digipa-high-impact -Alexandr Averin,0.58533764,fineart -Bertalan Székely,0.5548113,digipa-high-impact -Brothers Hildebrandt,0.5850233,fineart -Ed Roth,0.5849769,digipa-high-impact -Enki Bilal,0.58492255,fineart -Alan Lee,0.5848701,fineart -Charles H. Woodbury,0.5848688,fineart -André Charles Biéler,0.5847876,fineart -Annie Rose Laing,0.5597829,fineart -Matt Fraction,0.58463776,cartoon -Charles Alston,0.58453286,fineart -Frank Xavier Leyendecker,0.545465,fineart -Alfred Richard Gurrey,0.584306,fineart -Dan Mumford,0.5843051,cartoon -Francisco Martín,0.5842005,fineart -Alvaro Siza,0.58406967,digipa-high-impact -Frank J. Girardin,0.5839858,fineart -Henry Carr,0.58397424,digipa-high-impact -Charles Furneaux,0.58394694,fineart -Daniel F. Gerhartz,0.58389103,fineart -Gilberto Soren Zaragoza,0.5448442,fineart -Bart Sears,0.5838427,cartoon -Allison Bechdel,0.58383805,digipa-high-impact -Frank O'Meara,0.5837992,fineart -Charles Codman,0.5836579,fineart -Francisco Zúñiga,0.58359766,fineart -Vladimir Kush,0.49075457,fineart -Arnold Mesches,0.5834257,fineart -Frank McKelvey,0.5831641,fineart -Allen Butler Talcott,0.5830911,fineart -Eric Zener,0.58300316,fineart -Noah Bradley,0.44176096,digipa-low-impact -Robert Childress,0.58289623,fineart -Frances C. Fairman,0.5827239,fineart -Kathryn Morris Trotter,0.465856,digipa-low-impact -Everett Raymond Kinstler,0.5824819,fineart -Edward Mitchell Bannister,0.5804899,fineart -"George Barret, Jr.",0.5823128,fineart -Greg Hildebrandt,0.4271311,fineart -Anka Zhuravleva,0.5822078,digipa-high-impact -Rolf Armstrong,0.58217514,fineart -Eric Wallis,0.58191466,fineart -Clemens Ascher,0.5480207,digipa-high-impact -Hugo Kārlis Grotuss,0.5818766,fineart -Albert Paris Gütersloh,0.5817827,fineart -Hilda May Gordon,0.5817449,fineart -Hendrik Martenszoon Sorgh,0.5817126,fineart -Pipilotti Rist,0.5816868,digipa-high-impact -Hiroyuki Tajima,0.5816242,fareast -Igor Zenin,0.58159757,digipa-high-impact -Genevieve Springston Lynch,0.4979099,digipa-med-impact -Dan Witz,0.44476372,fineart -David Roberts,0.5255326,fineart -Frieke Janssens,0.5706969,digipa-high-impact -Arnold Schoenberg,0.56520367,fineart -Inoue Naohisa,0.5809933,fareast -Elfriede Lohse-Wächtler,0.58097905,fineart -Alex Ross,0.42460668,digipa-low-impact -Robert Irwin,0.58078,c -Charles Angrand,0.58077514,fineart -Anne Nasmyth,0.54221964,fineart -Henri Bellechose,0.5773891,fineart -De Hirsh Margules,0.58059025,fineart -Hiromitsu Takahashi,0.5805599,fareast -Ilya Kuvshinov,0.5805521,special -Cassius Marcellus Coolidge,0.5805516,c -Dorothy Burroughes,0.5804835,fineart -Emanuel de Witte,0.58027405,fineart -George Herbert Baker,0.5799624,digipa-high-impact -Cheng Zhengkui,0.57990086,fareast -Bernard Fleetwood-Walker,0.57987773,digipa-high-impact -Philippe Parreno,0.57985014,digipa-high-impact -Thornton Oakley,0.57969713,fineart -Greg Rutkowski,0.5203395,special -Ike no Taiga,0.5795857,anime -Eduardo Lefebvre Scovell,0.5795808,fineart -Adolfo Müller-Ury,0.57944727,fineart -Patrick Woodroffe,0.5228063,fineart -Wim Crouwel,0.57933235,digipa-high-impact -Colijn de Coter,0.5792779,fineart -François Boquet,0.57924724,fineart -Gerbrand van den Eeckhout,0.57897866,fineart -Eugenio Granell,0.5392264,fineart -Kuang Hong,0.5782304,digipa-high-impact -Justin Gerard,0.46685404,fineart -Tokujin Yoshioka,0.5779153,digipa-high-impact -Alan Bean,0.57788515,fineart -Ernest Biéler,0.5778079,fineart -Martin Deschambault,0.44401115,digipa-low-impact -Anna Boch,0.577735,fineart -Jack Davis,0.5775291,fineart -Félix Labisse,0.5775142,fineart -Greg Simkins,0.5679761,fineart -David Lynch,0.57751054,digipa-low-impact -Eizō Katō,0.5774127,digipa-high-impact -Grethe Jürgens,0.5773412,digipa-high-impact -Heinrich Bichler,0.5770147,fineart -Barbara Nasmyth,0.5446056,fineart -Domenico Induno,0.5583946,fineart -Gustave Baumann,0.5607866,fineart -Mike Mayhew,0.5765857,cartoon -Delmer J. Yoakum,0.576538,fineart -Aykut Aydogdu,0.43111503,digipa-low-impact -George Barker,0.5763551,fineart -Ernő Grünbaum,0.57634187,fineart -Eliseu Visconti,0.5763241,fineart -Esao Andrews,0.5761547,fineart -JennyBird Alcantara,0.49165845,digipa-med-impact -Joan Tuset,0.5761051,fineart -Angela Barrett,0.55976534,digipa-high-impact -Syd Mead,0.5758396,fineart -Ignacio Bazan-Lazcano,0.5757512,fineart -Franciszek Kostrzewski,0.57570386,fineart -Eero Järnefelt,0.57540673,fineart -Loretta Lux,0.56217635,digipa-high-impact -Gaudi,0.57519895,fineart -Charles Gleyre,0.57490873,fineart -Antoine Verney-Carron,0.56386137,fineart -Albert Edelfelt,0.57466495,fineart -Fabian Perez,0.57444525,fineart -Kevin Sloan,0.5737548,fineart -Stanislav Poltavsky,0.57434607,fineart -Abraham Hondius,0.574326,fineart -Tadao Ando,0.57429105,fareast -Fyodor Slavyansky,0.49796474,digipa-med-impact -David Brewster,0.57385933,digipa-high-impact -Cliff Chiang,0.57375133,digipa-high-impact -Drew Struzan,0.5317983,digipa-high-impact -Henry O. Tanner,0.5736586,fineart -Alberto Sughi,0.5736495,fineart -Albert J. Welti,0.5736257,fineart -Charles Mahoney,0.5735923,digipa-high-impact -Exekias,0.5734506,fineart -Felipe Seade,0.57342744,digipa-high-impact -Henriette Wyeth,0.57330644,digipa-high-impact -Harold Sandys Williamson,0.5443646,fineart -Eddie Campbell,0.57329535,digipa-high-impact -Gao Fenghan,0.5732926,fareast -Cynthia Sheppard,0.51099646,fineart -Henriette Grindat,0.573179,fineart -Yasutomo Oka,0.5731342,fareast -Celia Frances Bedford,0.57313216,fineart -Les Edwards,0.42068473,fineart -Edwin Deakin,0.5031717,fineart -Eero Saarinen,0.5725142,digipa-high-impact -Franciszek Smuglewicz,0.5722554,fineart -Doris Blair,0.57221186,fineart -Seb Mckinnon,0.51721895,digipa-med-impact -Gregorio Lazzarini,0.57204294,fineart -Gerard Sekoto,0.5719927,fineart -Francis Ernest Jackson,0.5506009,fineart -Simon Birch,0.57171595,digipa-high-impact -Bayard Wu,0.57171166,fineart -François Clouet,0.57162094,fineart -Christopher Wren,0.5715372,fineart -Evgeny Lushpin,0.5714827,special -Art Green,0.5714495,digipa-high-impact -Amy Judd,0.57142305,digipa-high-impact -Art Brenner,0.42619684,digipa-low-impact -Travis Louie,0.43916368,digipa-low-impact -James Jean,0.5457318,digipa-high-impact -Ewald Rübsamen,0.57083976,fineart -Donato Giancola,0.57052535,fineart -Carl Arnold Gonzenbach,0.5703996,fineart -Bastien Lecouffe-Deharme,0.5201288,fineart -Howard Chandler Christy,0.5702813,nudity -Dean Cornwell,0.56977296,fineart -Don Maitz,0.4743015,fineart -James Montgomery Flagg,0.56974065,fineart -Andreas Levers,0.42125136,digipa-low-impact -Edgar Schofield Baum,0.56965977,fineart -Alan Parry,0.5694952,digipa-high-impact -An Zhengwen,0.56942475,fareast -Alayna Lemmer,0.48293802,fineart -Edward Marshall Boehm,0.5530143,fineart -Henri Biva,0.54013556,nudity -Fiona Rae,0.4646715,digipa-low-impact -Elizabeth Jane Lloyd,0.5688463,digipa-high-impact -Franklin Carmichael,0.5687844,digipa-high-impact -Dionisius,0.56875896,fineart -Edwin Georgi,0.56868523,fineart -Jenny Saville,0.5686633,fineart -Ernest Hébert,0.56859314,fineart -Stephan Martiniere,0.56856346,digipa-high-impact -Huang Binhong,0.56841767,fineart -August Lemmer,0.5683548,fineart -Camille Bouvagne,0.5678048,fineart -Olga Skomorokhova,0.39401102,digipa-low-impact -Sacha Goldberger,0.5675477,digipa-high-impact -Hilda Annetta Walker,0.5675261,digipa-high-impact -Harvey Pratt,0.51314723,digipa-med-impact -Jean Bourdichon,0.5670543,fineart -Noriyoshi Ohrai,0.56690073,fineart -Kadir Nelson,0.5669006,n -Ilya Ostroukhov,0.5668801,fineart -Eugène Brands,0.56681967,fineart -Achille Leonardi,0.56674325,fineart -Franz Cižek,0.56670356,fineart -George Paul Chalmers,0.5665988,digipa-high-impact -Serge Marshennikov,0.5665971,digipa-high-impact -Mike Worrall,0.56641084,fineart -Dirck van Delen,0.5661764,fineart -Peter Andrew Jones,0.5661655,fineart -Rafael Albuquerque,0.56541103,fineart -Daniel Buren,0.5654043,fineart -Giuseppe Grisoni,0.5432699,fineart -George Fiddes Watt,0.55861616,fineart -Stan Lee,0.5651268,digipa-high-impact -Dorning Rasbotham,0.56511617,fineart -Albert Lynch,0.56497896,fineart -Lorenz Hideyoshi,0.56494075,fineart -Fenghua Zhong,0.56492203,fareast -Caroline Lucy Scott,0.49190843,digipa-med-impact -Victoria Crowe,0.5647996,digipa-high-impact -Hasegawa Settan,0.5647092,fareast -Dennis H. Farber,0.56453323,digipa-high-impact -Dick Bickenbach,0.5644289,fineart -Art Frahm,0.56439924,fineart -Edith Edmonds,0.5643151,fineart -Alfred Heber Hutty,0.56419206,fineart -Henry Tonks,0.56410825,fineart -Peter Howson,0.5640759,fineart -Albert Dorne,0.56395364,fineart -Arthur Adams,0.5639404,fineart -Bernt Tunold,0.56383425,digipa-high-impact -Gianluca Foli,0.5637317,digipa-high-impact -Vittorio Matteo Corcos,0.5636767,fineart -Béla Iványi-Grünwald,0.56355745,nudity -Feng Zhu,0.5634973,fineart -Sam Kieth,0.47251505,digipa-low-impact -Charles Crodel,0.5633834,fineart -Elsie Henderson,0.56310076,digipa-high-impact -George Earl Ortman,0.56295705,fineart -Tari Márk Dávid,0.562937,fineart -Betty Merken,0.56281745,digipa-high-impact -Cecile Walton,0.46672013,digipa-low-impact -Bracha L. Ettinger,0.56237936,fineart -Ken Fairclough,0.56230986,digipa-high-impact -Phil Koch,0.56224954,digipa-high-impact -George Pirie,0.56213045,digipa-high-impact -Chad Knight,0.56194013,digipa-high-impact -Béla Kondor,0.5427164,digipa-high-impact -Barclay Shaw,0.53689134,digipa-high-impact -Tim Hildebrandt,0.47194147,fineart -Hermann Rüdisühli,0.56104004,digipa-high-impact -Ian McQue,0.5342066,digipa-high-impact -Yanjun Cheng,0.5607171,fineart -Heinrich Hofmann,0.56060636,fineart -Henry Raleigh,0.5605958,fineart -Ernest Buckmaster,0.5605704,fineart -Charles Ricketts,0.56055415,fineart -Juergen Teller,0.56051147,digipa-high-impact -Auguste Mambour,0.5604873,fineart -Sean Yoro,0.5601486,digipa-high-impact -Sheilah Beckett,0.55995446,digipa-high-impact -Eugene Tertychnyi,0.5598978,fineart -Dr. Seuss,0.5597466,c -Adolf Wölfli,0.5372333,digipa-high-impact -Enrique Tábara,0.559323,fineart -Dionisio Baixeras Verdaguer,0.5590695,fineart -Aleksander Gierymski,0.5590013,fineart -Augustus Dunbier,0.55872476,fineart -Adolf Born,0.55848217,fineart -Chris Turnham,0.5584234,digipa-high-impact -James C Christensen,0.55837405,fineart -Daphne Fedarb,0.5582459,digipa-high-impact -Andre Kohn,0.5581832,special -Ron Mueck,0.5581811,nudity -Glenn Fabry,0.55786383,fineart -Elizabeth Polunin,0.5578102,digipa-high-impact -Charles S. Kaelin,0.5577954,fineart -Arthur Radebaugh,0.5577016,fineart -Ai Yazawa,0.55768114,fareast -Charles Roka,0.55762553,fineart -Ai Weiwei,0.5576034,digipa-high-impact -Dorothy Bradford,0.55760014,digipa-high-impact -Alfred Leslie,0.557555,fineart -Heinrich Herzig,0.5574423,fineart -Eliot Hodgkin,0.55740607,digipa-high-impact -Albert Kotin,0.55737317,fineart -Carlo Carlone,0.55729353,fineart -Chen Rong,0.5571221,fineart -Ikuo Hirayama,0.5570225,digipa-high-impact -Edward Corbett,0.55701995,nudity -Eugeniusz Żak,0.556925,nudity -Ettore Tito,0.556875,fineart -Helene Knoop,0.5567731,fineart -Amanda Sage,0.37731662,fareast -Annick Bouvattier,0.54647046,fineart -Harvey Dunn,0.55663586,fineart -Hans Sandreuter,0.5562575,digipa-high-impact -Ruan Jia,0.5398549,special -Anton Räderscheidt,0.55618906,fineart -Tyler Shields,0.4081434,digipa-low-impact -Darek Zabrocki,0.49975997,digipa-med-impact -Frank Montague Moore,0.5556432,fineart -Greg Staples,0.5555332,fineart -Endre Bálint,0.5553731,fineart -Augustus Vincent Tack,0.5136602,fineart -Marc Simonetti,0.48602036,fineart -Carlo Randanini,0.55493265,digipa-high-impact -Diego Dayer,0.5549119,fineart -Kelly Freas,0.55476534,fineart -Thomas Saliot,0.5139967,digipa-med-impact -Gijsbert d'Hondecoeter,0.55455256,fineart -Walter Kim,0.554521,digipa-high-impact -Francesco Cozza,0.5155097,digipa-med-impact -Bill Watterson,0.5542879,digipa-high-impact -Mark Keathley,0.4824056,fineart -Béni Ferenczy,0.55405354,digipa-high-impact -Amadou Opa Bathily,0.5536976,n -Giuseppe Antonio Petrini,0.55340284,fineart -Enzo Cucchi,0.55331933,digipa-high-impact -Adolf Schrödter,0.55316544,fineart -George Benjamin Luks,0.548566,fineart -Glenys Cour,0.55304,digipa-high-impact -Andrew Robertson,0.5529603,digipa-high-impact -Claude Rogers,0.55272067,digipa-high-impact -Alexandre Antigna,0.5526737,fineart -Aimé Barraud,0.55265915,digipa-high-impact -György Vastagh,0.55258965,fineart -Bruce Nauman,0.55257386,digipa-high-impact -Benjamin Block,0.55251944,digipa-high-impact -Gonzalo Endara Crow,0.552346,digipa-high-impact -Dirck de Bray,0.55221736,fineart -Gerald Kelley,0.5521059,digipa-high-impact -Dave Gibbons,0.5520954,digipa-high-impact -Béla Nagy Abodi,0.5520624,digipa-high-impact -Faith 47,0.5517006,digipa-high-impact -Anna Razumovskaya,0.5229187,digipa-med-impact -Archibald Robertson,0.55129635,digipa-high-impact -Louise Dahl-Wolfe,0.55120385,digipa-high-impact -Simon Bisley,0.55119276,digipa-high-impact -Eric Fischl,0.55107886,fineart -Hu Zaobin,0.5510481,fareast -Béla Pállik,0.5507963,digipa-high-impact -Eugene J. Martin,0.55078864,fineart -Friedrich Gauermann,0.55063415,fineart -Fritz Baumann,0.5341434,fineart -Michal Lisowski,0.5505639,fineart -Paolo Roversi,0.5503342,digipa-high-impact -Andrew Atroshenko,0.55009747,fineart -Gyula Derkovits,0.5500315,fineart -Hugh Adam Crawford,0.55000615,digipa-high-impact -Béla Apáti Abkarovics,0.5499799,digipa-high-impact -Paul Chadeisson,0.389151,digipa-low-impact -Aurél Bernáth,0.54968774,fineart -Albert Henry Krehbiel,0.54952574,fineart -Piet Hein Eek,0.54918796,digipa-high-impact -Yoshitaka Amano,0.5491855,fareast -Antonio Rotta,0.54909515,fineart -Józef Mehoffer,0.50760424,fineart -Donald Sherwood,0.5490415,digipa-high-impact -Catrin G Grosse,0.5489286,digipa-high-impact -Arthur Webster Emerson,0.5478842,fineart -Incarcerated Jerkfaces,0.5488423,digipa-high-impact -Emanuel Büchel,0.5487217,fineart -Andrew Loomis,0.54854584,fineart -Charles Hopkinson,0.54853606,fineart -Gabor Szikszai,0.5485203,digipa-high-impact -Archibald Standish Hartrick,0.54850936,digipa-high-impact -Aleksander Orłowski,0.546705,nudity -Hans Hinterreiter,0.5483628,fineart -Fred Williams,0.54544824,fineart -Fred A. Precht,0.5481606,fineart -Camille Souter,0.5213742,fineart -Emil Fuchs,0.54807395,fineart -Francesco Bonsignori,0.5478936,fineart -H. R. (Hans Ruedi) Giger,0.547799,fineart -Harriet Zeitlin,0.5477388,digipa-high-impact -Christian Jane Fergusson,0.5396168,fineart -Edward Kemble,0.5476892,fineart -Bernard Aubertin,0.5475396,fineart -Augustyn Mirys,0.5474162,fineart -Alejandro Burdisio,0.47482288,special -Erin Hanson,0.4343264,digipa-low-impact -Amalia Lindegren,0.5471987,digipa-high-impact -Alberto Seveso,0.47735062,fineart -Bartholomeus Strobel,0.54703736,fineart -Jim Davis,0.54703003,digipa-high-impact -Antony Gormley,0.54696125,digipa-high-impact -Charles Marion Russell,0.54696095,fineart -George B. Sutherland,0.5467901,fineart -Almada Negreiros,0.54670584,fineart -Edward Armitage,0.54358315,fineart -Bruno Walpoth,0.546167,digipa-high-impact -Richard Hamilton,0.5461275,nudity -Charles Harold Davis,0.5460415,digipa-high-impact -Fernand Verhaegen,0.54601514,fineart -Bernard Meninsky,0.5302034,digipa-high-impact -Fede Galizia,0.5456873,digipa-high-impact -Alfred Kelsner,0.5455753,nudity -Fritz Puempin,0.5452847,fineart -Alfred Charles Parker,0.54521024,fineart -Ahmed Yacoubi,0.544767,digipa-high-impact -Arthur B. Carles,0.54447794,fineart -Alice Prin,0.54435575,digipa-high-impact -Carl Gustaf Pilo,0.5443212,digipa-high-impact -Ross Tran,0.5259248,special -Hideyuki Kikuchi,0.544193,fareast -Art Fitzpatrick,0.49847245,fineart -Cherryl Fountain,0.5440454,fineart -Skottie Young,0.5440119,cartoon -NC Wyeth,0.54382974,digipa-high-impact -Rudolf Freund,0.5437342,fineart -Mort Kunstler,0.5433619,digipa-high-impact -Ben Goossens,0.53002644,digipa-high-impact -Andreas Rocha,0.49621177,special -Gérard Ernest Schneider,0.5429964,fineart -Francesco Filippini,0.5429598,digipa-high-impact -Alejandro Jodorowsky,0.5429065,digipa-high-impact -Friedrich Traffelet,0.5428817,fineart -Honor C. Appleton,0.5428735,digipa-high-impact -Jason A. Engle,0.542821,fineart -Henry Otto Wix,0.54271996,fineart -Gregory Manchess,0.54270375,fineart -Ann Stookey,0.54269934,digipa-high-impact -Henryk Rodakowski,0.542589,fineart -Albert Welti,0.5425134,digipa-high-impact -Gerard Houckgeest,0.5424413,digipa-high-impact -Dorothy Hood,0.54226196,digipa-high-impact -Frank Schoonover,0.51056194,fineart -Erlund Hudson,0.5422107,digipa-high-impact -Alexander Litovchenko,0.54210097,fineart -Sakai Hōitsu,0.5420294,digipa-high-impact -Benito Quinquela Martín,0.54194224,fineart -David Watson Stevenson,0.54191554,fineart -Ann Thetis Blacker,0.5416629,digipa-high-impact -Frank DuMond,0.51004076,digipa-med-impact -David Dougal Williams,0.5410126,digipa-high-impact -Robert Mcginnis,0.54098356,fineart -Ernest Briggs,0.5408636,fineart -Ferenc Joachim,0.5408625,fineart -Carlos Saenz de Tejada,0.47332364,digipa-low-impact -David Burton-Richardson,0.49659324,digipa-med-impact -Ernest Heber Thompson,0.54039246,digipa-high-impact -Albert Bertelsen,0.54038215,nudity -Giorgio Giulio Clovio,0.5403708,fineart -Eugene Leroy,0.54019785,digipa-high-impact -Anna Findlay,0.54018176,digipa-high-impact -Roy Gjertson,0.54012,digipa-high-impact -Charmion von Wiegand,0.5400893,fineart -Arnold Bronckhorst,0.526247,fineart -Boris Vallejo,0.487253,fineart -Adélaïde Victoire Hall,0.539939,fineart -Earl Norem,0.5398575,fineart -Sanford Kossin,0.53977877,digipa-high-impact -Aert de Gelder,0.519166,digipa-med-impact -Carl Eugen Keel,0.539739,digipa-high-impact -Francis Bourgeois,0.5397272,digipa-high-impact -Bojan Jevtic,0.41141546,fineart -Edward Avedisian,0.5393925,fineart -Gao Xiang,0.5392419,fareast -Charles Hinman,0.53911865,digipa-high-impact -Frits Van den Berghe,0.53896487,fineart -Carlo Martini,0.5384833,digipa-high-impact -Elina Karimova,0.5384318,digipa-high-impact -Anto Carte,0.4708289,digipa-low-impact -Andrey Yefimovich Martynov,0.537721,fineart -Frances Jetter,0.5376904,fineart -Yuri Ivanovich Pimenov,0.5342793,fineart -Gaston Anglade,0.537608,digipa-high-impact -Albert Swinden,0.5375844,fineart -Bob Byerley,0.5375774,fineart -A.B. Frost,0.5375025,fineart -Jaya Suberg,0.5372893,digipa-high-impact -Josh Keyes,0.53654516,digipa-high-impact -Juliana Huxtable,0.5364195,n -Everett Warner,0.53641814,digipa-high-impact -Hugh Kretschmer,0.45171157,digipa-low-impact -Arnold Blanch,0.535774,fineart -Ryan McGinley,0.53572595,digipa-high-impact -Alfons Karpiński,0.53564656,fineart -George Aleef,0.5355317,digipa-high-impact -Hal Foster,0.5351446,fineart -Stuart Immonen,0.53501946,digipa-high-impact -Craig Thompson,0.5346844,digipa-high-impact -Bartolomeo Vivarini,0.53465015,fineart -Hermann Feierabend,0.5346168,digipa-high-impact -Antonio Donghi,0.4610982,digipa-low-impact -Adonna Khare,0.4858036,digipa-med-impact -James Stokoe,0.5015107,digipa-med-impact -Agustín Fernández,0.53403986,fineart -Germán Londoño,0.5338712,fineart -Emmanuelle Moureaux,0.5335641,digipa-high-impact -Conrad Marca-Relli,0.5148334,digipa-med-impact -Gyula Batthyány,0.5332407,fineart -Francesco Raibolini,0.53314835,fineart -Apelles,0.5166026,fineart -Marat Latypov,0.45811993,fineart -Andrei Markin,0.5328752,fineart -Einar Hakonarson,0.5328311,digipa-high-impact -Beatrice Huntington,0.5328165,digipa-high-impact -Coppo di Marcovaldo,0.5327443,fineart -Gregorio Prestopino,0.53250784,fineart -A.D.M. Cooper,0.53244877,digipa-high-impact -Horatio McCulloch,0.53244334,digipa-high-impact -Wes Anderson,0.5318741,digipa-high-impact -Moebius,0.53178746,digipa-high-impact -Gerard Soest,0.53160626,fineart -Charles Ellison,0.53152347,digipa-high-impact -Wojciech Ostrycharz,0.5314213,fineart -Doug Chiang,0.5313724,fineart -Anne Savage,0.5310638,digipa-high-impact -Cor Melchers,0.53099334,fineart -Gordon Browne,0.5308195,digipa-high-impact -Augustus Earle,0.49196815,fineart -Carlos Francisco Chang Marín,0.5304734,fineart -Larry Elmore,0.53032553,fineart -Adolf Hölzel,0.5303149,fineart -David Ligare,0.5301894,fineart -Jan Luyken,0.52985555,fineart -Earle Bergey,0.5298525,fineart -David Ramsay Hay,0.52974963,digipa-high-impact -Alfred East,0.5296565,digipa-high-impact -A. R. Middleton Todd,0.50988734,fineart -Giorgio De Vincenzi,0.5291678,fineart -Hugh William Williams,0.5291014,digipa-high-impact -Erwin Bowien,0.52895796,digipa-high-impact -Victor Adame Minguez,0.5288686,fineart -Yoji Shinkawa,0.5287015,anime -Clara Weaver Parrish,0.5284487,digipa-high-impact -Albert Eckhout,0.5284096,fineart -Dorothy Coke,0.5282345,digipa-high-impact -Jerzy Duda-Gracz,0.5279943,digipa-high-impact -Byron Galvez,0.39178842,fareast -Alson S. Clark,0.5278568,digipa-high-impact -Adolf Ulric Wertmüller,0.5278296,digipa-high-impact -Bruce Coville,0.5277226,digipa-high-impact -Gong Kai,0.5276811,digipa-high-impact -Andréi Arinouchkine,0.52763486,digipa-high-impact -Florence Engelbach,0.5273161,digipa-high-impact -Brian Froud,0.5270276,fineart -Charles Thomson,0.5270127,digipa-high-impact -Bessie Wheeler,0.5269164,digipa-high-impact -Anton Lehmden,0.5268611,fineart -Emilia Wilk,0.5264961,fineart -Carl Eytel,0.52646196,digipa-high-impact -Alfred Janes,0.5264481,digipa-high-impact -Julie Bell,0.49962538,fineart -Eugenio de Arriba,0.52613926,digipa-high-impact -Samuel and Joseph Newsom,0.52595663,digipa-high-impact -Hans Falk,0.52588874,digipa-high-impact -Guillermo del Toro,0.52565175,digipa-high-impact -Félix Arauz,0.52555984,digipa-high-impact -Gyula Basch,0.52524436,digipa-high-impact -Haroon Mirza,0.5252279,digipa-high-impact -Du Jin,0.5249934,digipa-med-impact -Harry Shoulberg,0.5249456,digipa-med-impact -Arie Smit,0.5249027,fineart -Ahmed Karahisari,0.4259451,digipa-low-impact -Brian and Wendy Froud,0.5246335,fineart -E. William Gollings,0.52461207,digipa-med-impact -Bo Bartlett,0.51341593,digipa-med-impact -Hans Burgkmair,0.52416867,digipa-med-impact -David Macaulay,0.5241233,digipa-med-impact -Benedetto Caliari,0.52370214,digipa-med-impact -Eliott Lilly,0.5235398,digipa-med-impact -Vincent Tanguay,0.48578292,digipa-med-impact -Ada Hill Walker,0.52207166,fineart -Christopher Wood,0.49360397,digipa-med-impact -Kris Kuksi,0.43938053,digipa-low-impact -Chen Yifei,0.5217867,fineart -Margaux Valonia,0.5217782,digipa-med-impact -Antoni Pitxot,0.40582713,digipa-low-impact -Jhonen Vasquez,0.5216471,digipa-med-impact -Emilio Grau Sala,0.52156484,fineart -Henry B. Christian,0.52153796,fineart -Jacques Nathan-Garamond,0.52144086,digipa-med-impact -Eddie Mendoza,0.4949638,digipa-med-impact -Grzegorz Rutkowski,0.48906532,special -Beeple,0.40085253,digipa-low-impact -Giorgio Cavallon,0.5209209,digipa-med-impact -Godfrey Blow,0.52062386,digipa-med-impact -Gabriel Dawe,0.5204431,fineart -Emile Lahner,0.5202367,digipa-med-impact -Steve Dillon,0.5201676,digipa-med-impact -Lee Quinones,0.4626683,digipa-low-impact -Hale Woodruff,0.52000225,digipa-med-impact -Tom Hammick,0.5032626,digipa-med-impact -Hamilton Sloan,0.5197798,digipa-med-impact -Caesar Andrade Faini,0.51971483,digipa-med-impact -Sam Spratt,0.48991,digipa-med-impact -Chris Cold,0.4753577,fineart -Alejandro Obregón,0.5190562,digipa-med-impact -Dan Flavin,0.51901346,digipa-med-impact -Arthur Sarnoff,0.5189428,fineart -Elenore Abbott,0.5187141,digipa-med-impact -Andrea Kowch,0.51822996,digipa-med-impact -Demetrios Farmakopoulos,0.5181248,digipa-med-impact -Alexis Grimou,0.41958088,digipa-low-impact -Lesley Vance,0.5177536,digipa-med-impact -Gyula Aggházy,0.517747,fineart -Georgina Hunt,0.46105456,digipa-low-impact -Christian W. Staudinger,0.4684662,digipa-low-impact -Abraham Begeyn,0.5172538,digipa-med-impact -Charles Mozley,0.5171356,digipa-med-impact -Elias Ravanetti,0.38719344,digipa-low-impact -Herman van Swanevelt,0.5168748,digipa-med-impact -David Paton,0.4842217,digipa-med-impact -Hans Werner Schmidt,0.51671976,digipa-med-impact -Bob Ross,0.51628315,fineart -Sou Fujimoto,0.5162528,fareast -Balcomb Greene,0.5162045,digipa-med-impact -Glen Angus,0.51609933,digipa-med-impact -Buckminster Fuller,0.51607454,digipa-med-impact -Andrei Ryabushkin,0.5158933,fineart -Almeida Júnior,0.515856,digipa-med-impact -Tim White,0.4182697,digipa-low-impact -Hans Beat Wieland,0.51553553,digipa-med-impact -Jakub Różalski,0.5154904,digipa-med-impact -John Whitcomb,0.51523805,digipa-med-impact -Dorothy King,0.5150925,digipa-med-impact -Richard S. Johnson,0.51500344,fineart -Aniello Falcone,0.51475304,digipa-med-impact -Henning Jakob Henrik Lund,0.5147134,c -Robert M Cunningham,0.5144858,digipa-med-impact -Nick Knight,0.51447505,digipa-med-impact -David Chipperfield,0.51424,digipa-med-impact -Bartolomeo Cesi,0.5136737,digipa-med-impact -Bettina Heinen-Ayech,0.51334465,digipa-med-impact -Annabel Kidston,0.51327646,digipa-med-impact -Charles Schridde,0.51308405,digipa-med-impact -Samuel Earp,0.51305825,digipa-med-impact -Eugene Montgomery,0.5128343,digipa-med-impact -Alfred Parsons,0.5127445,digipa-med-impact -Anton Möller,0.5127209,digipa-med-impact -Craig Davison,0.499598,special -Cricorps Grégoire,0.51267076,fineart -Celia Fiennes,0.51266706,digipa-med-impact -Raymond Swanland,0.41350424,fineart -Howard Knotts,0.5122062,digipa-med-impact -Helmut Federle,0.51201206,digipa-med-impact -Tyler Edlin,0.44028252,digipa-high-impact -Elwood H. Smith,0.5119027,digipa-med-impact -Ralph Horsley,0.51142794,fineart -Alexander Ivanov,0.4539051,digipa-low-impact -Cedric Peyravernay,0.4200587,digipa-low-impact -Annabel Eyres,0.51136214,digipa-med-impact -Zack Snyder,0.51129746,digipa-med-impact -Gentile Bellini,0.511102,digipa-med-impact -Giovanni Pelliccioli,0.4868688,digipa-med-impact -Fikret Muallâ Saygı,0.510694,digipa-med-impact -Bauhaus,0.43454266,digipa-low-impact -Charles Williams,0.510406,digipa-med-impact -Georg Arnold-Graboné,0.5103381,digipa-med-impact -Fedot Sychkov,0.47935224,digipa-med-impact -Alberto Magnelli,0.5103212,digipa-med-impact -Aloysius O'Kelly,0.5102891,digipa-med-impact -Alexander McQueen,0.5101986,digipa-med-impact -Cam Sykes,0.510071,digipa-med-impact -George Lucas,0.510038,digipa-med-impact -Eglon van der Neer,0.5099339,digipa-med-impact -Christian August Lorentzen,0.50989646,digipa-med-impact -Eleanor Best,0.50966686,digipa-med-impact -Terry Redlin,0.474244,fineart -Ken Kelly,0.4304738,fineart -David Eugene Henry,0.48173362,fineart -Shin Jeongho,0.5092497,fareast -Flora Borsi,0.5091922,digipa-med-impact -Berndnaut Smilde,0.50864,digipa-med-impact -Art of Brom,0.45828784,fineart -Ernő Tibor,0.50851977,digipa-med-impact -Ancell Stronach,0.5084514,digipa-med-impact -Helen Thomas Dranga,0.45412368,digipa-low-impact -Anita Malfatti,0.5080986,digipa-med-impact -Arnold Brügger,0.5080749,digipa-med-impact -Edward Ben Avram,0.50778764,digipa-med-impact -Antonio Ciseri,0.5073538,fineart -Alyssa Monks,0.50734174,digipa-med-impact -Chen Zhen,0.5071876,digipa-med-impact -Francis Helps,0.50707847,digipa-med-impact -Georg Karl Pfahler,0.50700235,digipa-med-impact -Henry Woods,0.506811,digipa-med-impact -Barbara Greg,0.50674164,digipa-med-impact -Guan Daosheng,0.506712,fareast -Guy Billout,0.5064906,digipa-med-impact -Basuki Abdullah,0.50613165,digipa-med-impact -Thomas Visscher,0.5059943,digipa-med-impact -Edward Simmons,0.50598735,digipa-med-impact -Arabella Rankin,0.50572735,digipa-med-impact -Lady Pink,0.5056634,digipa-high-impact -Christopher Williams,0.5052288,digipa-med-impact -Fuyuko Matsui,0.5051116,fareast -Edward Baird,0.5049874,digipa-med-impact -Georges Stein,0.5049069,digipa-med-impact -Alex Alemany,0.43974748,digipa-low-impact -Emanuel Schongut,0.5047326,digipa-med-impact -Hans Bol,0.5045265,digipa-med-impact -Kurzgesagt,0.5043725,digipa-med-impact -Harald Giersing,0.50410193,digipa-med-impact -Antonín Slavíček,0.5040368,fineart -Carl Rahl,0.5040115,digipa-med-impact -Etienne Delessert,0.5037818,fineart -Americo Makk,0.5034161,digipa-med-impact -Fernand Pelez,0.5027561,digipa-med-impact -Alexey Merinov,0.4469615,digipa-low-impact -Caspar Netscher,0.5019529,digipa-med-impact -Walt Disney,0.50178146,digipa-med-impact -Qian Xuan,0.50150526,fareast -Geoffrey Dyer,0.50120556,digipa-med-impact -Andre Norton,0.5007602,digipa-med-impact -Daphne McClure,0.5007391,digipa-med-impact -Dieric Bouts,0.5005882,fineart -Aguri Uchida,0.5005107,fareast -Hugo Scheiber,0.50004864,digipa-med-impact -Kenne Gregoire,0.46421963,digipa-low-impact -Wolfgang Tillmans,0.4999767,fineart -Carl-Henning Pedersen,0.4998986,digipa-med-impact -Alison Debenham,0.4998683,digipa-med-impact -Eppo Doeve,0.49975222,digipa-med-impact -Christen Købke,0.49961317,digipa-med-impact -Aron Demetz,0.49895018,digipa-med-impact -Alesso Baldovinetti,0.49849576,digipa-med-impact -Jimmy Lawlor,0.4475271,fineart -Carl Walter Liner,0.49826378,fineart -Gwenny Griffiths,0.45598924,digipa-low-impact -David Cooke Gibson,0.4976222,digipa-med-impact -Howard Butterworth,0.4974621,digipa-med-impact -Bob Thompson,0.49743804,fineart -Enguerrand Quarton,0.49711192,fineart -Abdel Hadi Al Gazzar,0.49631482,digipa-med-impact -Gu Zhengyi,0.49629828,digipa-med-impact -Aleksander Kotsis,0.4953621,digipa-med-impact -Alexander Sharpe Ross,0.49519226,digipa-med-impact -Carlos Enríquez Gómez,0.49494863,digipa-med-impact -Abed Abdi,0.4948855,digipa-med-impact -Elaine Duillo,0.49474388,digipa-med-impact -Anne Said,0.49473995,digipa-med-impact -Istvan Banyai,0.4947369,digipa-med-impact -Bouchta El Hayani,0.49455142,digipa-med-impact -Chinwe Chukwuogo-Roy,0.49445248,n -George Claessen,0.49412063,digipa-med-impact -Axel Törneman,0.49401706,digipa-med-impact -Avigdor Arikha,0.49384058,digipa-med-impact -Gloria Stoll Karn,0.4937976,digipa-med-impact -Alfredo Volpi,0.49367586,digipa-med-impact -Raffaello Sanizo,0.49365884,digipa-med-impact -Jeff Easley,0.49344411,digipa-med-impact -Aileen Eagleton,0.49318358,digipa-med-impact -Gaetano Sabatini,0.49307147,digipa-med-impact -Bertalan Pór,0.4930132,digipa-med-impact -Alfred Jensen,0.49291304,digipa-med-impact -Huang Guangjian,0.49286693,fareast -Emil Ferris,0.49282396,digipa-med-impact -Derek Chittock,0.492694,digipa-med-impact -Alonso Vázquez,0.49205148,digipa-med-impact -Kelly Sue Deconnick,0.4919476,digipa-med-impact -Clive Madgwick,0.4749857,fineart -Edward George Handel Lucas,0.49166748,digipa-med-impact -Dorothea Braby,0.49161923,digipa-med-impact -Sangyeob Park,0.49150884,fareast -Heinz Edelman,0.49140438,digipa-med-impact -Mark Seliger,0.4912073,digipa-med-impact -Camilo Egas,0.4586727,digipa-low-impact -Craig Mullins,0.49085408,fineart -Dong Kingman,0.49063343,digipa-med-impact -Douglas Robertson Bisset,0.49031347,digipa-med-impact -Blek Le Rat,0.49008566,digipa-med-impact -Anton Ažbe,0.48984748,fineart -Olafur Eliasson,0.48971075,digipa-med-impact -Elinor Proby Adams,0.48967826,digipa-med-impact -Cándido López,0.48915705,digipa-med-impact -D. Howard Hitchcock,0.48902267,digipa-med-impact -Cheng Jiasui,0.48889247,fareast -Jean Nouvel,0.4888183,digipa-med-impact -Bill Gekas,0.48848945,digipa-med-impact -Hermione Hammond,0.48845994,digipa-med-impact -Fernando Gerassi,0.48841453,digipa-med-impact -Frank Barrington Craig,0.4883762,digipa-med-impact -A. B. Jackson,0.4883623,digipa-med-impact -Bernie D’Andrea,0.48813275,digipa-med-impact -Clarice Beckett,0.487809,digipa-med-impact -Dosso Dossi,0.48775777,digipa-med-impact -Donald Roller Wilson,0.48767656,digipa-med-impact -Ernest William Christmas,0.4876317,digipa-med-impact -Aleksandr Gerasimov,0.48736423,digipa-med-impact -Edward Clark,0.48703307,digipa-med-impact -Georg Schrimpf,0.48697302,digipa-med-impact -John Wilhelm,0.48696536,digipa-med-impact -Aries Moross,0.4863676,digipa-med-impact -Bill Lewis,0.48635158,digipa-med-impact -Huang Ji,0.48611963,fareast -F. Scott Hess,0.43634564,fineart -Gao Qipei,0.4860631,fareast -Albert Tucker,0.4854299,digipa-med-impact -Barbara Balmer,0.48528513,fineart -Anne Ryan,0.48511976,digipa-med-impact -Helen Edwards,0.48484707,digipa-med-impact -Alexander Bogen,0.48421195,digipa-med-impact -David Annand,0.48418126,digipa-med-impact -Du Qiong,0.48414314,fareast -Fred Cress,0.4837878,digipa-med-impact -David B. Mattingly,0.48370445,digipa-med-impact -Hristofor Žefarović,0.4837008,digipa-med-impact -Wim Wenders,0.44484183,digipa-low-impact -Alexander Fedosav,0.48360944,digipa-med-impact -Anne Rigney,0.48357943,digipa-med-impact -Bertalan Karlovszky,0.48338628,digipa-med-impact -George Frederick Harris,0.4833259,fineart -Toshiharu Mizutani,0.48315164,fareast -David McClellan,0.39739317,digipa-low-impact -Eugeen Van Mieghem,0.48270774,digipa-med-impact -Alexei Harlamoff,0.48255378,digipa-med-impact -Jeff Legg,0.48249072,digipa-med-impact -Elizabeth Murray,0.48227608,digipa-med-impact -Hugo Heyrman,0.48213717,digipa-med-impact -Adrian Paul Allinson,0.48211843,digipa-med-impact -Altoon Sultan,0.4820177,digipa-med-impact -Alice Mason,0.48188528,fareast -Harriet Powers,0.48181778,digipa-med-impact -Aaron Bohrod,0.48175076,digipa-med-impact -Chris Saunders,0.41429797,digipa-low-impact -Clara Miller Burd,0.47797233,digipa-med-impact -David G. Sorensen,0.38101727,digipa-low-impact -Iwan Baan,0.4806739,digipa-med-impact -Anatoly Metlan,0.48020265,digipa-med-impact -Alfons von Czibulka,0.4801954,digipa-med-impact -Amedee Ozenfant,0.47950014,digipa-med-impact -Valerie Hegarty,0.47947168,digipa-med-impact -Hugo Anton Fisher,0.4793551,digipa-med-impact -Antonio Roybal,0.4792729,digipa-med-impact -Cui Zizhong,0.47902682,fareast -F Scott Hess,0.42582104,fineart -Julien Delval,0.47888556,digipa-med-impact -Marcin Jakubowski,0.4788583,digipa-med-impact -Anne Stokes,0.4786997,digipa-med-impact -David Palumbo,0.47632077,fineart -Hallsteinn Sigurðsson,0.47858906,digipa-med-impact -Mike Campau,0.47850558,digipa-med-impact -Giuseppe Avanzi,0.47846943,digipa-med-impact -Harry Morley,0.47836518,digipa-med-impact -Constance-Anne Parker,0.47832203,digipa-med-impact -Albert Keller,0.47825447,digipa-med-impact -Daniel Chodowiecki,0.47825167,digipa-med-impact -Alasdair Grant Taylor,0.47802624,digipa-med-impact -Maria Pascual Alberich,0.4779718,fineart -Rebeca Saray,0.41697127,digipa-low-impact -Ernő Bánk,0.47753686,digipa-med-impact -Shaddy Safadi,0.47724134,digipa-med-impact -André Castro,0.4771826,digipa-med-impact -Amiet Cuno,0.41975892,digipa-low-impact -Adi Granov,0.40670198,fineart -Allen Williams,0.47675848,digipa-med-impact -Anna Haifisch,0.47672725,digipa-med-impact -Clovis Trouille,0.47669724,digipa-med-impact -Jane Graverol,0.47655866,digipa-med-impact -Conroy Maddox,0.47645602,digipa-med-impact -Božidar Jakac,0.4763106,digipa-med-impact -George Morrison,0.47533786,digipa-med-impact -Douglas Bourgeois,0.47527707,digipa-med-impact -Cao Zhibai,0.47476804,fareast -Bradley Walker Tomlin,0.47462896,digipa-low-impact -Dave Dorman,0.46852386,fineart -Stevan Dohanos,0.47452107,fineart -John Howe,0.44144905,fineart -Fanny McIan,0.47406268,digipa-low-impact -Bholekar Srihari,0.47387534,digipa-low-impact -Giovanni Lanfranco,0.4737344,digipa-low-impact -Fred Marcellino,0.47346023,digipa-low-impact -Clyde Caldwell,0.47305286,fineart -Haukur Halldórsson,0.47275954,digipa-low-impact -Huang Gongwang,0.47269204,fareast -Brothers Grimm,0.47249007,digipa-low-impact -Ollie Hoff,0.47240657,digipa-low-impact -RHADS,0.4722166,digipa-low-impact -Constance Gordon-Cumming,0.47219282,digipa-low-impact -Anne Mccaffrey,0.4719924,digipa-low-impact -Henry Heerup,0.47190166,digipa-low-impact -Adrian Smith,0.4716923,digipa-high-impact -Harold Elliott,0.4714101,digipa-low-impact -Eric Peterson,0.47106332,digipa-low-impact -David Garner,0.47106326,digipa-low-impact -Edward Hicks,0.4708863,digipa-low-impact -Alfred Krupa,0.47052455,digipa-low-impact -Breyten Breytenbach,0.4699338,digipa-low-impact -Douglas Shuler,0.4695691,digipa-low-impact -Elaine Hamilton,0.46941522,digipa-low-impact -Kapwani Kiwanga,0.46917036,digipa-low-impact -Dan Scott,0.46897763,digipa-low-impact -Allan Brooks,0.46882123,digipa-low-impact -Ian Fairweather,0.46878594,digipa-low-impact -Arlington Nelson Lindenmuth,0.4683814,digipa-low-impact -Russell Ayto,0.4681503,digipa-low-impact -Allan Linder,0.46812692,digipa-low-impact -Bohumil Kubista,0.4679809,digipa-low-impact -Christopher Jin Baron,0.4677839,digipa-low-impact -Eero Snellman,0.46777654,digipa-low-impact -Christabel Dennison,0.4677633,digipa-low-impact -Amelia Peláez,0.46764764,digipa-low-impact -James Gurney,0.46740666,digipa-low-impact -Carles Delclaux Is,0.46734855,digipa-low-impact -George Papazov,0.42420334,digipa-low-impact -Mark Brooks,0.4672415,fineart -Anne Dunn,0.46722376,digipa-low-impact -Klaus Wittmann,0.4670704,fineart -Arvid Nyholm,0.46697336,digipa-low-impact -Georg Scholz,0.46674117,digipa-low-impact -David Spriggs,0.46671993,digipa-low-impact -Ernest Morgan,0.4665036,digipa-low-impact -Ella Guru,0.46619284,digipa-low-impact -Helen Berman,0.46614346,digipa-low-impact -Gen Paul,0.4658785,digipa-low-impact -Auseklis Ozols,0.46569023,digipa-low-impact -Amelia Robertson Hill,0.4654411,fineart -Jim Lee,0.46544096,digipa-low-impact -Anson Maddocks,0.46539295,digipa-low-impact -Chen Hong,0.46516004,fareast -Haddon Sundblom,0.46490777,digipa-low-impact -Eva Švankmajerová,0.46454152,digipa-low-impact -Antonio Cavallucci,0.4645282,digipa-low-impact -Herve Groussin,0.40050638,digipa-low-impact -Gwen Barnard,0.46400994,digipa-low-impact -Grace English,0.4638674,digipa-low-impact -Carl Critchlow,0.4636,digipa-low-impact -Ayshia Taşkın,0.463412,digipa-low-impact -Alison Watt,0.43141022,digipa-low-impact -Andre de Krayewski,0.4628024,digipa-low-impact -Hamish MacDonald,0.462645,digipa-low-impact -Ni Chuanjing,0.46254826,fareast -Frank Mason,0.46254665,digipa-low-impact -Steve Henderson,0.43113405,fineart -Eileen Aldridge,0.46210572,digipa-low-impact -Brad Rigney,0.28446302,digipa-low-impact -Ching Yeh,0.46177,fareast -Bertram Brooker,0.46176457,digipa-low-impact -Henry Bright,0.46150023,digipa-low-impact -Claire Dalby,0.46117848,digipa-low-impact -Brian Despain,0.41538632,digipa-low-impact -Anna Maria Barbara Abesch,0.4611045,digipa-low-impact -Bernardo Daddi,0.46088326,digipa-low-impact -Abraham Mintchine,0.46088243,digipa-high-impact -Alexander Carse,0.46078917,digipa-low-impact -Doc Hammer,0.46075988,digipa-low-impact -Yuumei,0.46072406,digipa-low-impact -Teophilus Tetteh,0.46064255,n -Bess Hamiti,0.46062252,digipa-low-impact -Ceferí Olivé,0.46058378,digipa-low-impact -Enrique Grau,0.46046937,digipa-low-impact -Eleanor Hughes,0.46007007,digipa-low-impact -Elizabeth Charleston,0.46001568,digipa-low-impact -Félix Ziem,0.45987016,digipa-low-impact -Eugeniusz Zak,0.45985222,digipa-low-impact -Dain Yoon,0.45977795,fareast -Gong Xian,0.4595083,digipa-low-impact -Flavia Blois,0.45950204,digipa-low-impact -Frederik Vermehren,0.45949826,digipa-low-impact -Gang Se-hwang,0.45937777,digipa-low-impact -Bjørn Wiinblad,0.45934483,digipa-low-impact -Alex Horley-Orlandelli,0.42623433,digipa-low-impact -Dr. Atl,0.459287,digipa-low-impact -Hu Jieqing,0.45889485,fareast -Amédée Ozenfant,0.4585215,digipa-low-impact -Warren Ellis,0.4584044,digipa-low-impact -Helen Dahm,0.45804346,digipa-low-impact -Anne Geddes,0.45785287,digipa-low-impact -Bikash Bhattacharjee,0.45775396,digipa-low-impact -Phil Foglio,0.457582,digipa-low-impact -Evelyn Abelson,0.4574563,digipa-low-impact -Alan Moore,0.4573369,digipa-low-impact -Josh Kao,0.45725146,fareast -Bertil Nilsson,0.45724383,digipa-low-impact -Hristofor Zhefarovich,0.457089,fineart -Edward Bailey,0.45659882,digipa-low-impact -Christopher Moeller,0.45648077,digipa-low-impact -Dóra Keresztes,0.4558745,fineart -Cory Arcangel,0.4558071,digipa-low-impact -Aleksander Kobzdej,0.45552525,digipa-low-impact -Tim Burton,0.45541722,digipa-high-impact -Chen Jiru,0.4553378,fareast -George Passantino,0.4552104,digipa-low-impact -Fuller Potter,0.4552072,digipa-low-impact -Warwick Globe,0.45516664,digipa-low-impact -Heinz Anger,0.45466962,digipa-low-impact -Elias Goldberg,0.45416242,digipa-low-impact -tokyogenso,0.45406622,fareast -Zeen Chin,0.45404464,digipa-low-impact -Albert Koetsier,0.45385844,fineart -Giuseppe Camuncoli,0.45377725,digipa-low-impact -Elsie Vera Cole,0.45377362,digipa-low-impact -Andreas Franke,0.4300047,digipa-low-impact -Constantine Andreou,0.4533816,digipa-low-impact -Elisabeth Collins,0.45337808,digipa-low-impact -Ted Nasmith,0.45302224,fineart -Antônio Parreiras,0.45269623,digipa-low-impact -Gwilym Prichard,0.45256525,digipa-low-impact -Fang Congyi,0.45240825,fareast -Huang Ding,0.45233482,fareast -Hans von Bartels,0.45200723,digipa-low-impact -Peter Elson,0.4121406,fineart -Fan Kuan,0.4513034,digipa-low-impact -Dean Roger,0.45112592,digipa-low-impact -Bernat Sanjuan,0.45074993,fareast -Fletcher Martin,0.45055175,digipa-low-impact -Gentile Tondino,0.45043385,digipa-low-impact -Ei-Q,0.45038772,digipa-low-impact -Chen Lin,0.45035738,fareast -Ted Wallace,0.4500007,digipa-low-impact -"Cornelisz Hendriksz Vroom, the Younger",0.4499252,digipa-low-impact -Alpo Jaakola,0.44981295,digipa-low-impact -Clark Voorhees,0.4495309,digipa-low-impact -Cleve Gray,0.449188,digipa-low-impact -Wolf Kahn,0.4489858,digipa-low-impact -Choi Buk,0.44892842,fareast -Frank Tinsley,0.4480373,digipa-low-impact -George Bell,0.44779524,digipa-low-impact -Fiona Stephenson,0.44761062,fineart -Carlos Trillo Name,0.4470371,digipa-low-impact -Jamie McKelvie,0.44696707,digipa-low-impact -Dennis Flanders,0.44673377,digipa-low-impact -Dulah Marie Evans,0.44662604,digipa-low-impact -Hans Schwarz,0.4463275,digipa-low-impact -Steve McCurry,0.44620228,digipa-low-impact -Bedwyr Williams,0.44616276,digipa-low-impact -Anton Graff,0.38569996,digipa-low-impact -Leticia Gillett,0.44578317,digipa-low-impact -Rafał Olbiński,0.44561762,digipa-low-impact -Artgerm,0.44555497,fineart -Adrienn Henczné Deák,0.445518,digipa-low-impact -Gu Hongzhong,0.4454906,fareast -Matt Groening,0.44518438,digipa-low-impact -Sue Bryce,0.4447164,digipa-low-impact -Armin Baumgarten,0.444061,digipa-low-impact -Araceli Gilbert,0.44399196,digipa-low-impact -Carey Morris,0.44388965,digipa-low-impact -Ignat Bednarik,0.4438085,digipa-low-impact -Frank Buchser,0.44373792,digipa-low-impact -Ben Zoeller,0.44368798,digipa-low-impact -Adam Szentpétery,0.4434548,fineart -Gene Davis,0.44343877,digipa-low-impact -Fei Danxu,0.4433627,fareast -Andrei Kolkoutine,0.44328922,digipa-low-impact -Bruce Onobrakpeya,0.42588046,n -Christoph Amberger,0.38912287,digipa-low-impact -"Fred Mitchell,",0.4432277,digipa-low-impact -Klaus Burgle,0.44295216,digipa-low-impact -Carl Hoppe,0.44270635,digipa-low-impact -Caroline Gotch,0.44263047,digipa-low-impact -Hans Mertens,0.44260004,digipa-low-impact -Mandy Disher,0.44219893,fineart -Sarah Lucas,0.4420507,digipa-low-impact -Sydney Edmunds,0.44198513,digipa-low-impact -Amos Ferguson,0.4418735,digipa-low-impact -Alton Tobey,0.4416385,digipa-low-impact -Clifford Ross,0.44139367,digipa-low-impact -Henric Trenk,0.4412782,digipa-low-impact -Claire Hummel,0.44119984,digipa-low-impact -Norman Foster,0.4411899,digipa-low-impact -Carmen Saldana,0.44076762,digipa-low-impact -Michael Whelan,0.4372847,digipa-low-impact -Carlos Berlanga,0.440354,digipa-low-impact -Gilles Beloeil,0.43997732,digipa-low-impact -Ashley Wood,0.4398396,digipa-low-impact -David Allan,0.43969798,digipa-low-impact -Mark Lovett,0.43922082,digipa-low-impact -Jed Henry,0.43882954,digipa-low-impact -Adam Bruce Thomson,0.43847767,digipa-low-impact -Horst Antes,0.4384303,digipa-low-impact -Fritz Glarner,0.43787453,digipa-low-impact -Harold McCauley,0.43760818,digipa-low-impact -Estuardo Maldonado,0.437594,digipa-low-impact -Dai Jin,0.4375449,fareast -Fabien Charuau,0.43688047,digipa-low-impact -Chica Macnab,0.4365166,digipa-low-impact -Jim Burns,0.3975072,digipa-low-impact -Santiago Calatrava,0.43651623,digipa-low-impact -Robert Maguire,0.40926617,digipa-low-impact -Cliff Childs,0.43611953,digipa-low-impact -Charles Martin,0.43582463,fareast -Elbridge Ayer Burbank,0.43572164,digipa-low-impact -Anita Kunz,0.4356005,digipa-low-impact -Colin Geller,0.43559563,digipa-low-impact -Allen Tupper True,0.43556124,digipa-low-impact -Jef Wu,0.43555313,digipa-low-impact -Jon McCoy,0.4147122,digipa-low-impact -Cedric Seaut,0.43521535,digipa-low-impact -Emily Shanks,0.43519047,digipa-low-impact -Andrew Whem,0.43512022,digipa-low-impact -Ibrahim Kodra,0.43471518,digipa-low-impact -Harrington Mann,0.4345901,digipa-low-impact -Jerry Siegel,0.43458986,digipa-low-impact -Howard Kanovitz,0.4345178,digipa-low-impact -Cicely Hey,0.43449926,digipa-low-impact -Ben Thompson,0.43436068,digipa-low-impact -Joe Bowler,0.43413073,digipa-low-impact -Lori Earley,0.43389612,digipa-low-impact -Arent Arentsz,0.43373522,digipa-low-impact -David Bailly,0.43371305,digipa-low-impact -Hans Arnold,0.4335214,digipa-low-impact -Constance Copeman,0.4334836,digipa-low-impact -Brent Heighton,0.4333118,fineart -Eric Taylor,0.43312082,digipa-low-impact -Aleksander Gine,0.4326849,digipa-low-impact -Alexander Johnston,0.4326589,digipa-low-impact -David Park,0.43235332,digipa-low-impact -Balázs Diószegi,0.432244,digipa-low-impact -Ed Binkley,0.43222216,digipa-low-impact -Eric Dinyer,0.4321258,digipa-low-impact -Susan Luo,0.43198025,fareast -Cedric Seaut (Keos Masons),0.4317356,digipa-low-impact -Lorena Alvarez Gómez,0.431683,digipa-low-impact -Fred Ludekens,0.431662,digipa-low-impact -David Begbie,0.4316218,digipa-low-impact -Ai Xuan,0.43150818,fareast -Felix-Kelly,0.43132153,digipa-low-impact -Antonín Chittussi,0.431248,digipa-low-impact -Ammi Phillips,0.43095884,digipa-low-impact -Elke Vogelsang,0.43092483,digipa-low-impact -Fathi Hassan,0.43090487,digipa-low-impact -Angela Sung,0.391746,fareast -Clément Serveau,0.43050706,digipa-low-impact -Dong Yuan,0.4303865,fareast -Hew Lorimer,0.43035403,digipa-low-impact -David Finch,0.29487437,digipa-low-impact -Bill Durgin,0.4300932,digipa-low-impact -Alexander Robertson,0.4300743,digipa-low-impact diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py deleted file mode 100644 index c3bc1fd0..00000000 --- a/extensions-builtin/roll-artist/scripts/roll-artist.py +++ /dev/null @@ -1,50 +0,0 @@ -import random - -from modules import script_callbacks, shared -import gradio as gr - -art_symbol = '\U0001f3a8' # 🎨 -global_prompt = None -related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" } - - -def roll_artist(prompt): - allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories]) - artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) - - return prompt + ", " + artist.name if prompt != '' else artist.name - - -def add_roll_button(prompt): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) - - roll.click( - fn=roll_artist, - _js="update_txt2img_tokens", - inputs=[ - prompt, - ], - outputs=[ - prompt, - ] - ) - - -def after_component(component, **kwargs): - global global_prompt - - elem_id = kwargs.get('elem_id', None) - if elem_id not in related_ids: - return - - if elem_id == "txt2img_prompt": - global_prompt = component - elif elem_id == "txt2img_clear_prompt": - add_roll_button(global_prompt) - elif elem_id == "img2img_prompt": - global_prompt = component - elif elem_id == "img2img_clear_prompt": - add_roll_button(global_prompt) - - -script_callbacks.on_after_component(after_component) diff --git a/javascript/hints.js b/javascript/hints.js index f4079f96..ef410fba 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -14,7 +14,6 @@ titles = { "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", - "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "\u{1f4be}": "Save style", diff --git a/modules/api/api.py b/modules/api/api.py index 2c371e6e..f2e9e884 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -126,8 +126,6 @@ class Api: self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) @@ -390,12 +388,6 @@ class Api: return styleList - def get_artists_categories(self): - return shared.artist_db.cats - - def get_artists(self): - return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] - def get_embeddings(self): db = sd_hijack.model_hijack.embedding_db diff --git a/modules/artists.py b/modules/artists.py deleted file mode 100644 index 3612758b..00000000 --- a/modules/artists.py +++ /dev/null @@ -1,25 +0,0 @@ -import os.path -import csv -from collections import namedtuple - -Artist = namedtuple("Artist", ['name', 'weight', 'category']) - - -class ArtistsDatabase: - def __init__(self, filename): - self.cats = set() - self.artists = [] - - if not os.path.exists(filename): - return - - with open(filename, "r", newline='', encoding="utf8") as file: - reader = csv.DictReader(file) - - for row in reader: - artist = Artist(row["artist"], float(row["score"]), row["category"]) - self.artists.append(artist) - self.cats.add(artist.category) - - def categories(self): - return sorted(self.cats) diff --git a/modules/interrogate.py b/modules/interrogate.py index 738d8ff7..19938cbb 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -5,12 +5,13 @@ from collections import namedtuple import re import torch +import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram, modelloader +from modules import devices, paths, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +def download_default_clip_interrogate_categories(content_dir): + print("Downloading CLIP categories...") + + tmpdir = content_dir + "_tmp" + try: + os.makedirs(tmpdir) + + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) + torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) + + os.rename(tmpdir, content_dir) + + except Exception as e: + errors.display(e, "downloading default CLIP interrogate categories") + finally: + if os.path.exists(tmpdir): + os.remove(tmpdir) + + class InterrogateModels: blip_model = None clip_model = None clip_preprocess = None - categories = None dtype = None running_on_cpu = None def __init__(self, content_dir): - self.categories = [] + self.loaded_categories = None + self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") - if os.path.exists(content_dir): - for filename in os.listdir(content_dir): + def categories(self): + if self.loaded_categories is not None: + return self.loaded_categories + + self.loaded_categories = [] + + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if os.path.exists(self.content_dir): + for filename in os.listdir(self.content_dir): m = re_topn.search(filename) topn = 1 if m is None else int(m.group(1)) - with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file: + with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + + return self.loaded_categories def load_blip_model(self): import models.blip @@ -139,7 +172,6 @@ class InterrogateModels: shared.state.begin() shared.state.job = 'interrogate' try: - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() devices.torch_gc() @@ -159,12 +191,7 @@ class InterrogateModels: image_features /= image_features.norm(dim=-1, keepdim=True) - if shared.opts.interrogate_use_builtin_artists: - artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] - - res += ", " + artist[0] - - for name, topn, items in self.categories: + for name, topn, items in self.categories(): matches = self.rank(image_features, items, top_count=topn) for match, score in matches: if shared.opts.interrogate_return_ranks: diff --git a/modules/shared.py b/modules/shared.py index c0e11f18..72fb1934 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,7 +9,6 @@ from PIL import Image import gradio as gr import tqdm -import modules.artists import modules.interrogate import modules.memmon import modules.styles @@ -254,8 +253,6 @@ class State: state = State() state.server_start = time.time() -artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) - styles_filename = cmd_opts.styles_file prompt_styles = modules.styles.StyleDatabase(styles_filename) @@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), - "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index d23b2b8e..164e0e93 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di left, _ = os.path.splitext(filename) print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) - return [gr_show(True), None] + return [gr.update(), None] def interrogate(image): prompt = shared.interrogator.interrogate(image.convert("RGB")) - return gr_show(True) if prompt is None else prompt + return gr.update() if prompt is None else prompt def interrogate_deepbooru(image): prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt + return gr.update() if prompt is None else prompt def create_seed_inputs(target_interface): @@ -1039,19 +1039,18 @@ def create_ui(): init_img_inpaint, ], outputs=[img2img_prompt, dummy_component], - show_progress=False, ) img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) img2img_interrogate.click( - fn=lambda *args : process_interrogate(interrogate, *args), + fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, ) img2img_deepbooru.click( - fn=lambda *args : process_interrogate(interrogate_deepbooru, *args), + fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args, ) -- cgit v1.2.3 From 3262e825cc542ff634e6ba2e3a162eafdc6c1bba Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 21 Jan 2023 17:42:04 +0900 Subject: add --xformers-flash-attention option & impl --- modules/sd_hijack_optimizations.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..9967359b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + # print('xformers_attention_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): + # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + out = xformers.ops.memory_efficient_attention(q, k, v, op=op) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..23328adf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") -- cgit v1.2.3 From 63b824376c49013880ff44c260ea426e2899511e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 18:47:54 +0300 Subject: add --gradio-queue option to enable gradio queue --- modules/shared.py | 2 ++ webui.py | 3 +++ 2 files changed, 5 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..52bbb807 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -100,6 +100,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button") + script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) diff --git a/webui.py b/webui.py index 88d04840..d235da74 100644 --- a/webui.py +++ b/webui.py @@ -169,6 +169,9 @@ def webui(): shared.demo = modules.ui.create_ui() + if cmd_opts.gradio_queue: + shared.demo.queue(64) + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, -- cgit v1.2.3 From fe7a623e6b7e04bab2cfc96e8fd6cf49b48daee1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:02:41 +0300 Subject: add a slider for default value of added extra networks --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- javascript/hints.js | 3 ++- modules/shared.py | 5 +++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_hypernets.py | 3 ++- 5 files changed, 10 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 4406f8a0..54a80d36 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,3 +1,4 @@ +import json import os import lora @@ -26,7 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/javascript/hints.js b/javascript/hints.js index ef410fba..2aec71a9 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -107,7 +107,8 @@ titles = { "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", - "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", + "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it." } diff --git a/modules/shared.py b/modules/shared.py index 52bbb807..00a1d64c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,7 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), @@ -406,7 +406,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e2e060c8..4c88193f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -54,7 +54,7 @@ class ExtraNetworksPage: args = { "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', - "prompt": json.dumps(item["prompt"]), + "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 312dbaf0..65d000cf 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,3 +1,4 @@ +import json import os from modules import shared, ui_extra_networks @@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } -- cgit v1.2.3 From 2621566153920eb70bfa439f3d7c126ee8d36ec8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 08:05:21 +0300 Subject: attention ctrl+up/down enhancements --- javascript/edit-attention.js | 110 ++++++++++++++++++++++++++----------------- modules/shared.py | 8 ++-- 2 files changed, 71 insertions(+), 47 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index cec6a530..619bb1fa 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,74 +1,96 @@ -addEventListener('keydown', (event) => { +function keyupEditAttention(event){ let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; if (! (event.metaKey || event.ctrlKey)) return; - - let plus = "ArrowUp" - let minus = "ArrowDown" - if (event.key != plus && event.key != minus) return; + let isPlus = event.key == "ArrowUp" + let isMinus = event.key == "ArrowDown" + if (!isPlus && !isMinus) return; let selectionStart = target.selectionStart; let selectionEnd = target.selectionEnd; - // If the user hasn't selected anything, let's select their current parenthesis block - if (selectionStart === selectionEnd) { + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE){ + if (selectionStart !== selectionEnd) return false; + // Find opening parenthesis around current cursor - const before = target.value.substring(0, selectionStart); - let beforeParen = before.lastIndexOf("("); - if (beforeParen == -1) return; - let beforeParenClose = before.lastIndexOf(")"); + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen == -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf("(", beforeParen - 1); - beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1); + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); } // Find closing parenthesis around current cursor - const after = target.value.substring(selectionStart); - let afterParen = after.indexOf(")"); - if (afterParen == -1) return; - let afterParenOpen = after.indexOf("("); + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen == -1) return false; + let afterParenOpen = after.indexOf(OPEN); while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(")", afterParen + 1); - afterParenOpen = after.indexOf("(", afterParenOpen + 1); + afterParen = after.indexOf(CLOSE, afterParen + 1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); } - if (beforeParen === -1 || afterParen === -1) return; + if (beforeParen === -1 || afterParen === -1) return false; // Set the selection to the text between the parenthesis - const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen); + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; target.setSelectionRange(selectionStart, selectionEnd); - } + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block + if(! selectCurrentParenthesisBlock('<', '>')){ + selectCurrentParenthesisBlock('(', ')') + } event.preventDefault(); - if (selectionStart == 0 || target.value[selectionStart - 1] != "(") { - target.value = target.value.slice(0, selectionStart) + - "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" + - target.value.slice(selectionEnd); + closeCharacter = ')' + delta = opts.keyedit_precision_attention + + if (selectionStart > 0 && text[selectionStart - 1] == '<'){ + closeCharacter = '>' + delta = opts.keyedit_precision_extra + } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { + + // do not include spaces at the end + while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ + selectionEnd -= 1; + } + if(selectionStart == selectionEnd){ + return + } - target.focus(); - target.selectionStart = selectionStart + 1; - target.selectionEnd = selectionEnd + 1; + text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); - } else { - end = target.value.slice(selectionEnd + 1).indexOf(")") + 1; - weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end)); - if (isNaN(weight)) return; - if (event.key == minus) weight -= 0.1; - if (event.key == plus) weight += 0.1; + selectionStart += 1; + selectionEnd += 1; + } - weight = parseFloat(weight.toPrecision(12)); + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; - target.value = target.value.slice(0, selectionEnd + 1) + - weight + - target.value.slice(selectionEnd + 1 + end - 1); + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if(String(weight).length == 1) weight += ".0" - target.focus(); - target.selectionStart = selectionStart; - target.selectionEnd = selectionEnd; - } + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; updateInput(target) -}); +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 00a1d64c..d68ac296 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -444,9 +444,11 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), - 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('ui', "Live previews"), { -- cgit v1.2.3 From 35419b274614984e2b511a6ad34f37e41481c809 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 11:00:05 +0300 Subject: add an option to reorder tabs for extra networks --- javascript/hints.js | 3 ++- modules/shared.py | 1 + modules/ui_extra_networks.py | 18 +++++++++++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index 833543f0..3cf10e20 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -109,7 +109,8 @@ titles = { "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", - "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights." + "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", + "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited." } diff --git a/modules/shared.py b/modules/shared.py index d68ac296..cd78e50a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,6 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c88193f..285c8ffe 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -79,6 +79,22 @@ class ExtraNetworksUi: self.tabname = None +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] @@ -86,7 +102,7 @@ def create_ui(container, button, tabname): ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - for page in ui.stored_extra_pages: + for page in pages_in_preferred_order(ui.stored_extra_pages): with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) -- cgit v1.2.3 From 66eef11ce7f3db108225668c573cb4a763a43fb3 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 18:27:57 -0300 Subject: feat(extra-networks): add default view setting --- modules/shared.py | 4 ++++ modules/ui_extra_networks.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..e9548864 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,6 +430,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ce4801b5..179ba47a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view = 'cards'): + def create_html(self, tabname, view=shared.opts.extra_networks_default_view): items_html = '' for item in self.list_items(): @@ -111,7 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,7 +119,7 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view='cards'): + def refresh(view): res = [] for pg in ui.stored_extra_pages: @@ -142,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename, view='cards'): + def save_preview(index, images, filename, view): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] -- cgit v1.2.3 From b5230197a69d36a79fdc4919c59a03e00e872dd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 09:24:43 +0300 Subject: rework extras tab to use script system --- javascript/ui.js | 5 - modules/api/api.py | 13 +- modules/postprocessing.py | 236 ++++++++----------------------- modules/scripts.py | 28 ++-- modules/scripts_postprocessing.py | 147 +++++++++++++++++++ modules/shared.py | 5 + modules/ui.py | 265 +---------------------------------- modules/ui_common.py | 202 ++++++++++++++++++++++++++ modules/ui_components.py | 6 - modules/ui_postprocessing.py | 57 ++++++++ scripts/postprocessing_codeformer.py | 36 +++++ scripts/postprocessing_gfpgan.py | 33 +++++ scripts/postprocessing_upscale.py | 106 ++++++++++++++ 13 files changed, 675 insertions(+), 464 deletions(-) create mode 100644 modules/scripts_postprocessing.py create mode 100644 modules/ui_common.py create mode 100644 modules/ui_postprocessing.py create mode 100644 scripts/postprocessing_codeformer.py create mode 100644 scripts/postprocessing_gfpgan.py create mode 100644 scripts/postprocessing_upscale.py (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index 77256e15..ba72623c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){ return res } -function get_extras_tab_index(){ - const [,,...args] = [...arguments] - return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] -} - function get_img2img_tab_index() { let res = args_to_array(arguments) res.splice(-2) diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..5d60fc0a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -45,10 +44,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -244,7 +241,7 @@ class Api: reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +257,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cb85720b..8514fea7 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,219 +1,103 @@ -from __future__ import annotations import os -import numpy as np from PIL import Image -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import shared, images, devices, ui_components +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste from modules.shared import opts -import modules.gfpgan_model -import modules.codeformer_model - - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) -def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin() shared.state.job = 'extras' - imageArr = [] - # Also keep track of original file names - imageNameArr = [] + image_data = [] + image_names = [] outputs = [] if extras_mode == 1: - #convert file to pillow image for img in image_folder: image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + image_data.append(image) + image_names.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + assert input_dir, 'input directory not selected' - if input_dir == '': - return outputs, "Please select an input directory.", '' image_list = shared.listfiles(input_dir) - for img in image_list: + for filename in image_list: try: - image = Image.open(img) + image = Image.open(filename) except Exception: continue - imageArr.append(image) - imageNameArr.append(img) + image_data.append(image) + image_names.append(filename) else: - imageArr.append(image) - imageNameArr.append(None) + assert image, 'image not selected' + + image_data.append(image) + image_names.append(None) if extras_mode == 2 and output_dir != '': outpath = output_dir else: outpath = opts.outdir_samples or opts.outdir_extras_samples - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - + infotext = '' + + for image, name in zip(image_data, image_names): + shared.state.textinfo = name + existing_pnginfo = image.info or {} - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) + pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] + scripts.scripts_postproc.run(pp, args) + + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] else: basename = '' - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + if save_output: + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) - if extras_mode != 2 or show_extras_results : - outputs.append(image) + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) devices.torch_gc() - return outputs, ui_components.plaintext_to_html(info), '' - - -def clear_cache(): - cached_images.clear() - + return outputs, ui_common.plaintext_to_html(infotext), '' + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + """old handler for API""" + + args = scripts.scripts_postproc.create_args_for_run({ + "Upscale": { + "upscale_mode": resize_mode, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + }, + "GFPGAN": { + "gfpgan_visibility": gfpgan_visibility, + }, + "CodeFormer": { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + }, + }) + + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/scripts.py b/modules/scripts.py index 4ffc369b..03907a63 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions, script_loading +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() @@ -150,8 +150,10 @@ def basedir(): return current_basedir -scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) + +scripts_data = [] +postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) @@ -190,23 +192,31 @@ def list_files_with_name(filename): def load_scripts(): global current_basedir scripts_data.clear() + postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() scripts_list = list_scripts("scripts", ".py") syspath = sys.path + def register_scripts_from_module(module): + for key, script_class in module.__dict__.items(): + if type(script_class) != type: + continue + + if issubclass(script_class, Script): + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): + postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + for scriptfile in sorted(scripts_list): try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - module = script_loading.load_module(scriptfile.path) - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + script_module = script_loading.load_module(scriptfile.path) + register_scripts_from_module(script_module) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -413,6 +423,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_current: ScriptRunner = None @@ -423,12 +434,13 @@ def reload_script_body_only(): def reload_scripts(): - global scripts_txt2img, scripts_img2img + global scripts_txt2img, scripts_img2img, scripts_postproc load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() def IOComponent_init(self, *args, **kwargs): diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 00000000..25de02d0 --- /dev/null +++ b/modules/scripts_postprocessing.py @@ -0,0 +1,147 @@ +import os +import gradio as gr + +from modules import errors, shared + + +class PostprocessedImage: + def __init__(self, image): + self.image = image + self.info = {} + + +class ScriptPostprocessing: + filename = None + controls = None + args_from = None + args_to = None + + order = 1000 + """scripts will be ordred by this value in postprocessing UI""" + + name = None + """this function should return the title of the script.""" + + group = None + """A gr.Group component that has all script's UI inside it""" + + def ui(self): + """ + This function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be a dictionary that maps parameter names to components used in processing. + Values of those components will be passed to process() function. + """ + + pass + + def process(self, pp: PostprocessedImage, **args): + """ + This function is called to postprocess the image. + args contains a dictionary with all values returned by components from ui() + """ + + pass + + def image_changed(self): + pass + + +def wrap_call(func, filename, funcname, *args, default=None, **kwargs): + try: + res = func(*args, **kwargs) + return res + except Exception as e: + errors.display(e, f"calling {filename}/{funcname}") + + return default + + +class ScriptPostprocessingRunner: + def __init__(self): + self.scripts = None + self.ui_created = False + + def initialize_scripts(self, scripts_data): + self.scripts = [] + + for script_class, path, basedir, script_module in scripts_data: + script: ScriptPostprocessing = script_class() + script.filename = path + + self.scripts.append(script) + + def create_script_ui(self, script, inputs): + script.args_from = len(inputs) + script.args_to = len(inputs) + + script.controls = wrap_call(script.ui, script.filename, "ui") + + for control in script.controls.values(): + control.custom_script_source = os.path.basename(script.filename) + + inputs += list(script.controls.values()) + script.args_to = len(inputs) + + def scripts_in_preferred_order(self): + if self.scripts is None: + import modules.scripts + self.initialize_scripts(modules.scripts.postprocessing_scripts_data) + + scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + + def script_score(name): + name = name.lower() + for i, possible_match in enumerate(scripts_order): + if possible_match in name: + return i + + return len(self.scripts) + + script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)} + + return sorted(self.scripts, key=lambda x: script_scores[x.name]) + + def setup_ui(self): + inputs = [] + + for script in self.scripts_in_preferred_order(): + with gr.Box() as group: + self.create_script_ui(script, inputs) + + script.group = group + + self.ui_created = True + return inputs + + def run(self, pp: PostprocessedImage, args): + for script in self.scripts_in_preferred_order(): + shared.state.job = script.name + + script_args = args[script.args_from:script.args_to] + + process_args = {} + for (name, component), value in zip(script.controls.items(), script_args): + process_args[name] = value + + script.process(pp, **process_args) + + def create_args_for_run(self, scripts_args): + if not self.ui_created: + with gr.Blocks(analytics_enabled=False): + self.setup_ui() + + scripts = self.scripts_in_preferred_order() + args = [None] * max([x.args_to for x in scripts]) + + for script in scripts: + script_args_dict = scripts_args.get(script.name, None) + if script_args_dict is not None: + + for i, name in enumerate(script.controls): + args[script.args_from + i] = script_args_dict.get(name, None) + + return args + + def image_changed(self): + for script in self.scripts_in_preferred_order(): + script.image_changed() diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..cb73bf31 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -474,6 +474,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), +})) + options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), diff --git a/modules/ui.py b/modules/ui.py index 4116e167..8cb8e613 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -86,7 +86,6 @@ css_hide_progressbar = """ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 @@ -95,7 +94,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - return ui_components.plaintext_to_html(text) + return ui_common.plaintext_to_html(text) def send_gradio_gallery_to_image(x): @@ -103,70 +102,6 @@ def send_gradio_gallery_to_image(x): return None return image_from_url_text(x[0]) -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -444,19 +379,6 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) - -def update_generation_info(generation_info, html_info, img_index): - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info, gr.update() - return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info, gr.update() - - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -477,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ], - show_progress=False, - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return ui_common.create_output_panel(tabname, outdir) def create_sampler_and_steps_selection(choices, tabname): @@ -1106,86 +928,7 @@ def create_ui(): modules.scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='compact'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=postprocessing.clear_cache, - inputs=[], outputs=[] - ) + ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 00000000..8ce75b8c --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,202 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import scipy as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images + +folder_symbol = '\U0001f4c2' # 📂 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text): + text = "

    " + "
    \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

    " + return text + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/modules/ui_components.py b/modules/ui_components.py index 989cc87b..9aec3097 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,5 +1,3 @@ -import html - import gradio as gr @@ -50,7 +48,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" - -def plaintext_to_html(text): - text = "

    " + "
    \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

    " - return text diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 00000000..b418d955 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + + +def create_ui(): + tab_index = gr.State(value=0) + + with gr.Row().style(equal_height=False, variant='compact'): + with gr.Column(variant='compact'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + script_inputs = scripts.scripts_postproc.setup_ui() + + with gr.Column(): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py new file mode 100644 index 00000000..a7d80d40 --- /dev/null +++ b/scripts/postprocessing_codeformer.py @@ -0,0 +1,36 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, codeformer_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py new file mode 100644 index 00000000..d854f3f7 --- /dev/null +++ b/scripts/postprocessing_gfpgan.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, gfpgan_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py new file mode 100644 index 00000000..095d29b2 --- /dev/null +++ b/scripts/postprocessing_upscale.py @@ -0,0 +1,106 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, shared +import gradio as gr + +from modules.ui_components import FormRow + + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info[f"Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info[f"Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() -- cgit v1.2.3 From 925dd09c91e7338aef72e4ec99d67b8b57280215 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 09:03:17 -0500 Subject: improve interrogate --- modules/interrogate.py | 29 +++++++++++++++++------------ modules/shared.py | 1 + 2 files changed, 18 insertions(+), 12 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/interrogate.py b/modules/interrogate.py index 19938cbb..1d1ac572 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +category_types = ["artists", "flavors", "mediums", "movements"] def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") @@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir): tmpdir = content_dir + "_tmp" try: os.makedirs(tmpdir) - - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) - + for category_type in category_types: + torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt")) os.rename(tmpdir, content_dir) except Exception as e: @@ -51,12 +48,13 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None + self.selected_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None: - return self.loaded_categories + if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + return self.loaded_categories self.loaded_categories = [] @@ -64,14 +62,19 @@ class InterrogateModels: download_default_clip_interrogate_categories(self.content_dir) if os.path.exists(self.content_dir): - for filename in os.listdir(self.content_dir): + self.selected_categories = shared.opts.interrogate_clip_categories + for category_type in category_types: + if 'all' not in self.selected_categories and category_type not in self.selected_categories: + continue + filename = os.path.join(self.content_dir, f"{category_type}.txt") + if not os.path.isfile(filename): + continue m = re_topn.search(filename) topn = 1 if m is None else int(m.group(1)) - - with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: + with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) return self.loaded_categories @@ -139,6 +142,8 @@ class InterrogateModels: def rank(self, image_features, text_array, top_count=1): import clip + devices.torch_gc() + if shared.opts.interrogate_clip_dict_limit != 0: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] diff --git a/modules/shared.py b/modules/shared.py index a644c0be..63b236c5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.3 From 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 12:29:23 -0500 Subject: add option to skip interrogate categories --- modules/interrogate.py | 32 ++++++++++++++++++-------------- modules/shared.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/interrogate.py b/modules/interrogate.py index 1d1ac572..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") -category_types = ["artists", "flavors", "mediums", "movements"] +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) for category_type in category_types: @@ -48,33 +53,32 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None - self.selected_categories = [] + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: return self.loaded_categories self.loaded_categories = [] - if not os.path.exists(self.content_dir): - download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - self.selected_categories = shared.opts.interrogate_clip_categories - for category_type in category_types: - if 'all' not in self.selected_categories and category_type not in self.selected_categories: - continue - filename = os.path.join(self.content_dir, f"{category_type}.txt") - if not os.path.isfile(filename): + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in self.skip_categories: continue - m = re_topn.search(filename) + m = re_topn.search(filename.stem) topn = 1 if m is None else int(m.group(1)) with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories diff --git a/modules/shared.py b/modules/shared.py index d7a18f6a..5f713bee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,7 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), - "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.3 From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). --- README.md | 1 + modules/deepbooru_model.py | 4 +++- modules/devices.py | 2 ++ modules/processing.py | 15 ++++++++------- modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++ modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++ modules/sd_models.py | 10 ++++++++++ modules/shared.py | 1 + 8 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 modules/sd_hijack_utils.py (limited to 'modules/shared.py') diff --git a/README.md b/README.md index 9c0cd1ef..a5611671 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - (You) diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py index edd40c81..83d2ff09 100644 --- a/modules/deepbooru_model.py +++ b/modules/deepbooru_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import devices + # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more @@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): t_358, = inputs t_359 = t_358.permute(*[0, 3, 1, 2]) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) - t_360 = self.n_Conv_0(t_359_padded) + t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) t_361 = F.relu(t_360) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_362 = self.n_MaxPool_0(t_361) diff --git a/modules/devices.py b/modules/devices.py index 524ec7af..0981ef80 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -79,6 +79,8 @@ cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False def randn(seed, shape): diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..2d186ba0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,8 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -203,7 +204,7 @@ class StableDiffusionProcessing: # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. - conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) + conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -211,7 +212,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -225,10 +226,10 @@ class StableDiffusionProcessing: # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) + return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(): + with devices.autocast(disable=devices.unet_needs_upcast): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] @@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(shared.device) + image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc class TorchHijackForUnet: @@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +if version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 00000000..f81b169a --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-2, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() diff --git a/modules/shared.py b/modules/shared.py index 5f713bee..4ce1209b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") -- cgit v1.2.3 From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. --- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 +++++++++++++++++++++++-------------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale - - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice -- cgit v1.2.3 From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 26 Jan 2023 23:29:27 +0300 Subject: add an option to enable sections from extras tab in txt2img/img2img fix some style inconsistenices --- modules/processing.py | 7 +++++- modules/scripts.py | 32 ++++++++++++++++++++++---- modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++++++++++ modules/scripts_postprocessing.py | 11 ++++++--- modules/shared.py | 15 ++++-------- modules/shared_items.py | 10 ++++++++ modules/ui_components.py | 8 +++++++ scripts/postprocessing_upscale.py | 25 ++++++++++++++++++++ style.css | 6 +---- 9 files changed, 133 insertions(+), 23 deletions(-) create mode 100644 modules/scripts_auto_postprocessing.py create mode 100644 modules/shared_items.py (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 92894d67..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image(p, pp) + image = pp.image + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,12 +6,16 @@ from collections import namedtuple import gradio as gr -from modules.processing import StableDiffusionProcessing from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() +class PostprocessImageArgs: + def __init__(self, image): + self.image = image + + class Script: filename = None args_from = None @@ -65,7 +69,7 @@ class Script: args contains all values returned by components from ui() """ - raise NotImplementedError() + pass def process(self, p, *args): """ @@ -100,6 +104,13 @@ class Script: pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = [] def initialize_scripts(self, is_img2img): + from modules import scripts_auto_postprocessing + self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir, script_module in scripts_data: + auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + + for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img @@ -332,7 +347,7 @@ class ScriptRunner: return inputs - def run(self, p: StableDiffusionProcessing, *args): + def run(self, p, *args): script_index = args[0] if script_index == 0: @@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_image(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image(p, pp, *script_args) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared + + +class ScriptPostprocessingForMainUI(scripts.Script): + def __init__(self, script_postproc): + self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc + self.postprocessing_controls = None + + def title(self): + return self.script.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.postprocessing_controls = self.script.ui() + return self.postprocessing_controls.values() + + def postprocess_image(self, p, script_pp, *args): + args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} + + pp = scripts_postprocessing.PostprocessedImage(script_pp.image) + pp.info = {} + self.script.process(pp, **args_dict) + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + + +def create_auto_preprocessing_script_data(): + from modules import scripts + + res = [] + + for name in shared.opts.postprocessing_enable_in_main_ui: + script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) + if script is None: + continue + + constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) + res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) + + return res diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass + + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) @@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class() script.filename = path + if script.name == "Simple Upscale": + continue + self.scripts.append(script) def create_script_ui(self, script, inputs): @@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: import modules.scripts self.initialize_scripts(modules.scripts.postprocessing_scripts_data) - scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + scripts_order = shared.opts.postprocessing_operation_order def script_score(name): - name = name.lower() for i, possible_match in enumerate(scripts_order): - if possible_match in name: + if possible_match == name: return i return len(self.scripts) @@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: def image_changed(self): for script in self.scripts_in_preferred_order(): script.image_changed() + diff --git a/modules/shared.py b/modules/shared.py index 6a0b96cb..cdeed55d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,8 +13,8 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components -from modules.paths import models_path, script_path, sd_path +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules.paths import models_path, script_path demo = None @@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] - -def realesrgan_models_names(): - import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] - - class OptionInfo: def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default @@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) @@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { - 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), })) diff --git a/modules/shared_items.py b/modules/shared_items.py new file mode 100644 index 00000000..b5d480c9 --- /dev/null +++ b/modules/shared_items.py @@ -0,0 +1,10 @@ + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + +def postprocessing_scripts(): + import modules.scripts + + return modules.scripts.scripts_postproc.scripts \ No newline at end of file diff --git a/modules/ui_components.py b/modules/ui_components.py index 9aec3097..284ca0cf 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + +class DropdownMulti(gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 095d29b2..8842bd91 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): def image_changed(self): upscale_cache.clear() + + +class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): + name = "Simple Upscale" + order = 900 + + def ui(self): + with FormRow(): + upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) + + return { + "upscale_by": upscale_by, + "upscaler_name": upscaler_name, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + if upscaler_name is None or upscaler_name == "None": + return + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) + assert upscaler1, f'could not find upscaler named {upscaler_name}' + + pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) + pp.info[f"Postprocess upscaler"] = upscaler1.name diff --git a/style.css b/style.css index ec046f78..dd914104 100644 --- a/style.css +++ b/style.css @@ -164,7 +164,7 @@ min-height: 3.2em; } -#txt2img_styles ul, #img2img_styles ul{ +ul.list-none{ max-height: 35em; z-index: 2000; } @@ -714,9 +714,6 @@ footer { white-space: nowrap; min-width: auto; } -#txt2img_hires_fix{ - margin-left: -0.8em; -} #img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ margin-left: 0em; @@ -744,7 +741,6 @@ footer { .dark .gr-compact{ background-color: rgb(31 41 55 / var(--tw-bg-opacity)); - margin-left: 0.8em; } .gr-compact{ -- cgit v1.2.3 From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:28:12 +0300 Subject: remove the need to place configs near models --- configs/instruct-pix2pix.yaml | 99 +++++++++++++++ configs/v1-inpainting-inference.yaml | 70 +++++++++++ modules/api/api.py | 5 +- modules/devices.py | 12 +- modules/sd_hijack_inpainting.py | 9 -- modules/sd_models.py | 228 +++++++++++++++++------------------ modules/sd_models_config.py | 65 ++++++++++ modules/shared.py | 7 +- modules/shared_items.py | 15 ++- modules/timer.py | 35 ++++++ v2-inference-v.yaml | 68 ----------- 11 files changed, 411 insertions(+), 202 deletions(-) create mode 100644 configs/instruct-pix2pix.yaml create mode 100644 configs/v1-inpainting-inference.yaml create mode 100644 modules/sd_models_config.py create mode 100644 modules/timer.py delete mode 100644 v2-inference-v.yaml (limited to 'modules/shared.py') diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml new file mode 100644 index 00000000..437ddcef --- /dev/null +++ b/configs/instruct-pix2pix.yaml @@ -0,0 +1,99 @@ +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +model: + base_learning_rate: 1.0e-04 + target: modules.models.diffusion.ddpm_edit.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: edited + cond_stage_key: edit + # image_size: 64 + # image_size: 32 + image_size: 16 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: true + load_ema: true + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 0 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + num_workers: 1 + wrap: false + validation: + target: edit_dataset.EditDataset + params: + path: data/clip-filtered-dataset + cache_dir: data/ + cache_name: data_10k + split: val + min_text_sim: 0.2 + min_image_sim: 0.75 + min_direction_sim: 0.2 + max_samples_per_prompt: 1 + min_resize_res: 512 + max_resize_res: 512 + crop_res: 512 + output_as_edit: False + real_input: True diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml new file mode 100644 index 00000000..f9eec37d --- /dev/null +++ b/configs/v1-inpainting-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..eb7b1da5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, find_checkpoint_config +from modules.sd_models import checkpoints_list +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -387,7 +388,7 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..2d5f797a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 31d2c898..478cd499 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def should_hijack_inpainting(checkpoint_info): - from modules import sd_models - - ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() - - return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename - - def do_inpainting_hijack(): # p_sample_plms is needed because PLMS can't work with dicts as conditionings diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072eb2e..fa208728 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path import sys import gc -import time -from collections import namedtuple import torch import re import safetensors.torch @@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path -from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting -from modules.sd_hijack_ip2p import should_hijack_ip2p +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) -def find_checkpoint_config(info): - if info is None: - return shared.cmd_opts.config - - config = os.path.splitext(info.filename)[0] + ".yaml" - if os.path.exists(config): - return config - - return shared.cmd_opts.config - - def list_models(): checkpoints_list.clear() checkpoint_alisases.clear() @@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location - if device is None: - device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) @@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo): +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + if checkpoint_info.title != title: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if cache_enabled and checkpoint_info in checkpoints_loaded: - # use checkpoint cache - print(f"Loading weights [{sd_model_hash}] from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - else: - # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + model.load_state_dict(state_dict, strict=False) + del state_dict + timer.record("apply weights to model") - sd = read_state_dict(checkpoint_info.filename) - model.load_state_dict(sd, strict=False) - del sd - - if cache_enabled: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) - if not shared.cmd_opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.cmd_opts.upcast_sampling and depth_model: - model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model + timer.record("apply half()") - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype_unet = model.model.diffusion_model.dtype - devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") # clean up cache if limit is reached - if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model - checkpoints_loaded.popitem(last=False) # LRU + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename @@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") def enable_midas_autodownload(): @@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper -class Timer: - def __init__(self): - self.start = time.time() +def repair_config(sd_config): - def elapsed(self): - end = time.time() - res = end - self.start - self.start = end - return res + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True -def load_model(checkpoint_info=None): + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - checkpoint_config = find_checkpoint_config(checkpoint_info) - - if checkpoint_config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_config) - - if should_hijack_inpainting(checkpoint_info): - # Hardcoded config for now... - sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.unet_config.params.in_channels = 9 - sd_config.model.params.finetune_keys = None - - if should_hijack_ip2p(checkpoint_info): - sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.first_stage_key = "edited" - sd_config.model.params.cond_stage_key = "edit" - sd_config.model.params.image_size = 16 - sd_config.model.params.unet_config.params.in_channels = 8 - sd_config.model.params.unet_config.params.out_channels = 4 + do_inpainting_hijack() - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False + timer = Timer() - do_inpainting_hijack() + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - timer = Timer() + timer.record("find config") - sd_model = None + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) @@ -407,29 +392,35 @@ def load_model(checkpoint_info=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) - elapsed_create = timer.elapsed() + sd_model.used_config = checkpoint_config - load_model_weights(sd_model, checkpoint_info) + timer.record("create model") - elapsed_load_weights = timer.elapsed() + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: sd_model.to(shared.device) + timer.record("move model to device") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() shared.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("load textual inversion embeddings") + script_callbacks.model_loaded_callback(sd_model) - elapsed_the_rest = timer.elapsed() + timer.record("scripts callbacks") - print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - checkpoint_config = find_checkpoint_config(current_checkpoint_info) - - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): - del sd_model - checkpoints_loaded.clear() - load_model(checkpoint_info) - return shared.sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None): timer = Timer() + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + try: - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, state_dict, timer) except Exception as e: print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info) + load_model_weights(sd_model, current_checkpoint_info, None, timer) raise finally: sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) + timer.record("move model to device") - elapsed = timer.elapsed() - - print(f"Weights loaded in {elapsed:.1f}s.") + print(f"Weights loaded in {timer.summary()}.") return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py new file mode 100644 index 00000000..ea773a10 --- /dev/null +++ b/modules/sd_models_config.py @@ -0,0 +1,65 @@ +import re +import os + +from modules import shared, paths + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") + + +config_default = shared.sd_default_config +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") + +re_parametrization_v = re.compile(r'-v\b') + + +def guess_model_config_from_state_dict(sd, filename): + fn = os.path.basename(filename) + + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + return config_sd2v + else: + return config_sd2 + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if roberta_weight is not None: + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/shared.py b/modules/shared.py index cdeed55d..14be993d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,13 +13,14 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules import localization, extensions, script_loading, errors, ui_components, shared_items from modules.paths import models_path, script_path demo = None -sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") +sd_configs_path = os.path.join(script_path, "configs") +sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), diff --git a/modules/shared_items.py b/modules/shared_items.py index b5d480c9..8b5ec96d 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -4,7 +4,20 @@ def realesrgan_models_names(): import modules.realesrgan_model return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + def postprocessing_scripts(): import modules.scripts - return modules.scripts.scripts_postproc.scripts \ No newline at end of file + return modules.scripts.scripts_postproc.scripts + + +def sd_vae_items(): + import modules.sd_vae + + return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) + + +def refresh_vae_list(): + import modules.sd_vae + + return modules.sd_vae.refresh_vae_list diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 00000000..57a4f17a --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,35 @@ +import time + + +class Timer: + def __init__(self): + self.start = time.time() + self.records = {} + self.total = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def record(self, category, extra_time=0): + e = self.elapsed() + if category not in self.records: + self.records[category] = 0 + + self.records[category] += e + extra_time + self.total += e + extra_time + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [x for x in self.records.items() if x[1] >= 0.1] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res diff --git a/v2-inference-v.yaml b/v2-inference-v.yaml deleted file mode 100644 index 513cd635..00000000 --- a/v2-inference-v.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - parameterization: "v" - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" \ No newline at end of file -- cgit v1.2.3 From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:15:42 +0100 Subject: add data-dir flag and set all user data directories based on it --- modules/extensions.py | 2 +- modules/generation_parameters_copypaste.py | 4 ++-- modules/gfpgan_model.py | 5 ++--- modules/hashes.py | 4 +++- modules/interrogate.py | 2 +- modules/paths.py | 10 +++++++++- modules/processing.py | 3 ++- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 5 ++--- modules/shared.py | 11 ++++++----- modules/textual_inversion/preprocess.py | 5 ++--- modules/ui.py | 6 +++--- modules/ui_extensions.py | 2 +- modules/upscaler.py | 5 ++--- 14 files changed, 39 insertions(+), 31 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/extensions.py b/modules/extensions.py index b522125c..92ee8144 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,7 @@ import git from modules import paths, shared extensions = [] -extensions_dir = os.path.join(paths.script_path, "extensions") +extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..35f72808 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.shared import script_path +from modules.paths import data_path, script_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: - filename = os.path.join(script_path, "params.txt") + filename = os.path.join(data_path, "params.txt") if os.path.exists(filename): with open(filename, "r", encoding="utf8") as file: prompt = file.read() diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 1e2dbc32..fbe6215a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -6,12 +6,11 @@ import facexlib import gfpgan import modules.face_restoration -from modules import shared, devices, modelloader -from modules.paths import models_path +from modules import paths, shared, devices, modelloader model_dir = "GFPGAN" user_path = None -model_path = os.path.join(models_path, model_dir) +model_path = os.path.join(paths.models_path, model_dir) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None diff --git a/modules/hashes.py b/modules/hashes.py index b85a7580..819362a3 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -4,8 +4,10 @@ import os.path import filelock +from modules.paths import data_path -cache_filename = "cache.json" + +cache_filename = os.path.join(data_path, "cache.json") cache_data = None diff --git a/modules/interrogate.py b/modules/interrogate.py index c72ff694..cbb80683 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -12,7 +12,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram, modelloader, errors +from modules import devices, paths, shared, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' diff --git a/modules/paths.py b/modules/paths.py index 20b3e4d8..08e6f9b9 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -4,7 +4,15 @@ import sys import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -models_path = os.path.join(script_path, "models") + +# Parse the --data-dir flag first so we can use it as a base for our other argument default values +parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) +cmd_opts_pre = parser.parse_known_args()[0] +data_path = cmd_opts_pre.data_dir +models_path = os.path.join(data_path, "models") + +# data_path = cmd_opts_pre.data sys.path.insert(0, script_path) # search for directory of stable diffusion in following places diff --git a/modules/processing.py b/modules/processing.py index 262806a1..5072fc40 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared +import modules.paths as paths import modules.face_restoration import modules.images as images import modules.styles @@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) diff --git a/modules/sd_models.py b/modules/sd_models.py index 37dad18d..b2d48a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,13 +12,13 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} checkpoint_alisases = {} @@ -307,7 +307,7 @@ def enable_midas_autodownload(): location automatically. """ - midas_path = os.path.join(models_path, 'midas') + midas_path = os.path.join(paths.models_path, 'midas') # stable-diffusion-stability-ai hard-codes the midas model path to # a location that differs from where other scripts using this model look. diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 4ce238b8..9b00f76e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,13 +3,12 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks, sd_models -from modules.paths import models_path +from modules import paths, shared, devices, script_callbacks, sd_models import glob from copy import deepcopy -vae_path = os.path.abspath(os.path.join(models_path, "VAE")) +vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_dict = {} diff --git a/modules/shared.py b/modules/shared.py index 14be993d..474fcc42 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.styles import modules.devices as devices from modules import localization, extensions, script_loading, errors, ui_components, shared_items -from modules.paths import models_path, script_path +from modules.paths import models_path, script_path, data_path demo = None @@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") @@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") -parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") @@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) -parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index c0ac11d3..2239cb84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -6,8 +6,7 @@ import sys import tqdm import time -from modules import shared, images, deepbooru -from modules.paths import models_path +from modules import paths, shared, images, deepbooru from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop @@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) except Exception as e: print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..0117df3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -1497,8 +1497,8 @@ def create_ui(): with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 742e745e..66a41865 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url): normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.script_path, "tmp", dirname) + tmpdir = os.path.join(paths.data_path, "tmp", dirname) try: shutil.rmtree(tmpdir, True) diff --git a/modules/upscaler.py b/modules/upscaler.py index a5bf5acb..e2eaa730 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -11,7 +11,6 @@ from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) -from modules.paths import models_path class Upscaler: @@ -39,7 +38,7 @@ class Upscaler: self.mod_scale = None if self.model_path is None and self.name: - self.model_path = os.path.join(models_path, self.name) + self.model_path = os.path.join(shared.models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) @@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler): def __init__(self, dirname=None): super().__init__(False) self.name = "Nearest" - self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file + self.scalers = [UpscalerData("Nearest", None, self)] -- cgit v1.2.3