author     Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-04 06:38:16 +0000
committer  Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-04 06:38:16 +0000
commit     70e66e81e56f0bb187395878ee49705acbfab40c (patch)
tree       b7745d409c7c9192eec0fe0722acf316f8342f79
parent     c134a480164bef017cd4b33fae57a31a86556beb (diff)
parent     f0c1063a707a4a43823b0ed00e2a8eeb22a9ed0a (diff)
Merge branch 'dev' into efficient-vae-methods
-rw-r--r--  modules/hypernetworks/hypernetwork.py           5
-rw-r--r--  modules/processing.py                           8
-rw-r--r--  modules/sd_hijack.py                            6
-rw-r--r--  modules/sd_samplers_common.py                   6
-rw-r--r--  modules/sd_vae.py                               3
-rw-r--r--  modules/textual_inversion/textual_inversion.py  4
6 files changed, 18 insertions, 14 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c4821d21..70f1cbd2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -10,7 +10,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images
+ from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
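The hunk above moves the `processing` import from module scope into train_hypernetwork, the same trick the removed comment describes for `images`: presumably importing processing at the top of this file would complete a circular import. A minimal two-module sketch of the deferred-import pattern, with hypothetical module names (not the repo's layout):

    # pkg_a.py (hypothetical)
    import pkg_b          # safe: nothing from pkg_a is needed while pkg_b loads

    def helper():
        return "data from pkg_a"

    # pkg_b.py (hypothetical)
    def do_work():
        # A top-level `from pkg_a import helper` here would fail, because pkg_a is
        # still half-initialised when pkg_b loads as part of pkg_a's own import.
        # Importing inside the function runs only at call time, when both modules
        # are fully loaded.
        from pkg_a import helper
        return helper()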
diff --git a/modules/processing.py b/modules/processing.py
index 099d86b7..aae39866 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+decode_first_stage = sd_samplers_common.decode_first_stage
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -572,13 +573,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
return samples
-def decode_first_stage(model, x):
- from modules.sd_samplers_common import samples_to_images_tensor, approximation_indexes
- x = x.to(devices.dtype_vae)
- approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
- return samples_to_images_tensor(x, approx_index, model)
-
-
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
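The two hunks in this file pair up: the old decode_first_stage body is deleted, and the `decode_first_stage = sd_samplers_common.decode_first_stage` line added near the imports re-exports the relocated function, so code that still calls processing.decode_first_stage keeps working. A minimal sketch of that move-and-alias pattern, with hypothetical module names:

    # new_home.py (hypothetical) -- the function's new location
    def convert(x):
        return x * 2

    # old_home.py (hypothetical) -- the old location keeps a compatibility alias
    import new_home

    # callers that still use `old_home.convert(...)` resolve to the moved function
    convert = new_home.convert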
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfa5f0eb..609fd56c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,7 +2,6 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
@@ -164,12 +163,13 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
def __init__(self):
+ import modules.textual_inversion.textual_inversion
+
self.extra_generation_params = {}
self.comments = []
+ self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
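Two things change in StableDiffusionModelHijack above: the textual_inversion import moves from module scope into __init__, so it is only resolved when an instance is constructed, and embedding_db becomes a per-instance attribute instead of a class attribute created when the class body runs at import time. A small self-contained illustration of the class-attribute vs. instance-attribute distinction (generic names, not the repo's classes):

    class Registry:
        def __init__(self):
            # built fresh for every instance, at construction time
            self.items = {}

    r1, r2 = Registry(), Registry()
    r1.items["a"] = 1
    assert "a" not in r2.items   # per-instance state; a class-level dict would be shared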
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index d444cac1..2cfa4ac6 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -54,6 +54,12 @@ def single_sample_to_image(sample, approximation=None):
return Image.fromarray(x_sample)
+def decode_first_stage(model, x):
+ x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+ return x
+
+
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
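The relocated decode_first_stage added here simply casts the latent to devices.dtype_vae and calls the model's own decode_first_stage. Separately, the sd_vae.py hunk below sorts vae_dict with shared.natural_sort_key so VAE names list in human order (vae2 before vae10). A hedged sketch of what such a key typically looks like; this is an illustrative stand-in, not necessarily shared.natural_sort_key's actual implementation:

    import re

    def natural_sort_key(s):
        # split into digit and non-digit runs so "vae2" sorts before "vae10"
        return [int(part) if part.isdigit() else part.lower()
                for part in re.split(r'(\d+)', s)]

    names = ["vae10.pt", "vae2.pt", "VAE1.safetensors"]
    print(sorted(names, key=natural_sort_key))
    # ['VAE1.safetensors', 'vae2.pt', 'vae10.pt']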
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index e4ff2994..84271db0 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -50,6 +50,7 @@ def get_filename(filepath):
def refresh_vae_list():
+ global vae_dict
vae_dict.clear()
paths = [
@@ -83,6 +84,8 @@ def refresh_vae_list():
name = get_filename(filepath)
vae_dict[name] = filepath
+ vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
def find_vae_near_checkpoint(checkpoint_file):
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4713bc2d..aa79dc09 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -13,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ from modules import processing
+
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)