diff options
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 48 |
1 files changed, 23 insertions, 25 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8314450a..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,10 +14,12 @@ import torch from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
def __init__(self, dim, state_dict=None):
super().__init__()
@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device)
def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork:
@@ -223,31 +229,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps:
return hypernetwork, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text, cond) in pbar:
+ for i, entry in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except Exception:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- cond = cond.to(devices.device)
- x = x.to(devices.device)
+ cond = entry.cond.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -267,7 +265,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -282,16 +280,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, )
processed = processing.process_images(p)
- image = processed.images[0]
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -299,7 +297,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, <p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
|