author | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-10-29 11:09:17 +0000
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-10-29 11:09:17 +0000
commit | ab27c111d06ec920791c73eea25ad9a61671852e (patch)
tree | 3ac74302db74a758eceef3f3b7bb74efd6077f32 /modules/hypernetworks/hypernetwork.py
parent | 35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff)
Add input validations before loading dataset for training
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 38
1 file changed, 22 insertions(+), 16 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2e84583b..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
-    assert hypernetwork_name, 'hypernetwork not selected'
+    save_hypernetwork_every = save_hypernetwork_every or 0
+    create_image_every = create_image_every or 0
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
@@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         images_dir = None
+    hypernetwork = shared.loaded_hypernetwork
+
+    ititial_step = hypernetwork.step or 0
+    if ititial_step > steps:
+        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+        return hypernetwork, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
-    hypernetwork = shared.loaded_hypernetwork
-    weights = hypernetwork.weights()
-    for weight in weights:
-        weight.requires_grad = True
-
     size = len(ds.indexes)
     loss_dict = defaultdict(lambda : deque(maxlen = 1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
-
-    last_saved_file = "<none>"
-    last_saved_image = "<none>"
-    forced_filename = "<none>"
-
-    ititial_step = hypernetwork.step or 0
-    if ititial_step > steps:
-        return hypernetwork, filename
-
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+    weights = hypernetwork.weights()
+    for weight in weights:
+        weight.requires_grad = True
     # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
     steps_without_grad = 0
+    last_saved_file = "<none>"
+    last_saved_image = "<none>"
+    forced_filename = "<none>"
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
         hypernetwork.step = i + ititial_step
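The first hunk replaces a single bare assert with a call to textual_inversion.validate_train_inputs, whose implementation lives in modules/textual_inversion/textual_inversion.py and is not part of this diff. A minimal sketch of what such a pre-flight validator can look like; the specific checks and assert messages below are illustrative assumptions, not the project's exact code:

```python
import os

def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file,
                          steps, save_model_every, create_image_every, log_directory,
                          name="embedding"):
    # Sketch only: fail fast with a readable message before any expensive work starts.
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int) and batch_size > 0, "Batch size must be a positive integer"
    assert data_root and os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file and os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert isinstance(steps, int) and steps > 0, "Max steps must be a positive integer"
    # Periodic checkpoints/previews need somewhere to be written.
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
```

The `or 0` normalization in the hunk ensures save_hypernetwork_every and create_image_every arrive here as integers even when the UI passes None.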
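The second hunk also moves the LearnRateScheduler construction (defined in modules/textual_inversion/learn_schedule.py, outside this diff) ahead of the slow dataset load, passing the resumed step so a multi-phase schedule picks up mid-stream. A rough sketch of that idea, assuming a "rate:until_step" schedule string such as "5e-3:100, 5e-4:1000"; the parsing format and attribute names are assumptions for illustration:

```python
class LearnRateScheduler:
    """Illustrative sketch only, not the project's actual implementation."""

    def __init__(self, learn_rate: str, max_steps: int, cur_step: int = 0):
        # Parse "rate:until_step" pairs; a bare rate applies for all steps.
        self.phases = []
        for part in str(learn_rate).split(","):
            rate, _, end = part.strip().partition(":")
            self.phases.append((float(rate), int(end) if end else max_steps))
        # Resuming: start from the phase covering cur_step, which is why the
        # constructor receives the hypernetwork's saved step count.
        self.learn_rate = self.rate_at(cur_step)

    def rate_at(self, step: int) -> float:
        for rate, until in self.phases:
            if step < until:
                return rate
        return self.phases[-1][0]  # past the end of the schedule: keep the final rate
```

Under this sketch, the hunk's `optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)` reads the rate appropriate to the step the run resumed at rather than always starting from the first phase.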