diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-18 20:04:24 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-18 20:04:24 +0000 |
commit | 924e222004ab54273806c5f2ca7a0e7cfa76ad83 (patch) | |
tree | 153a08105ee2bc87df43a8a1423df96d25a8e19b /modules/textual_inversion | |
parent | 889b851a5260ce869a3286ad15d17d1bbb1da0a7 (diff) | |
download | stable-diffusion-webui-gfx803-924e222004ab54273806c5f2ca7a0e7cfa76ad83.tar.gz stable-diffusion-webui-gfx803-924e222004ab54273806c5f2ca7a0e7cfa76ad83.tar.bz2 stable-diffusion-webui-gfx803-924e222004ab54273806c5f2ca7a0e7cfa76ad83.zip |
add option to show/hide warnings
removed hiding warnings from LDSR
fixed/reworked a few places that produced warnings
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 6 |
1 file changed, 5 insertions, 1 deletion
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7e4a6d24..5a7be422 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -15,7 +15,7 @@ import numpy as np from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|