diff options
author | Greg Fuller <gfuller23@gmail.com> | 2022-10-12 19:44:41 +0000 |
---|---|---|
committer | Greg Fuller <gfuller23@gmail.com> | 2022-10-12 19:44:41 +0000 |
commit | fb3cefb348600497d964c4fd3f99138d7dde0ec5 (patch) | |
tree | 984c02da13dae6ab0d6c21304a551bbfffea02f2 /modules/hypernetworks | |
parent | d717eb079cd6b7fa7a4f97c0a10d400bdec753fb (diff) | |
parent | 698d303b04e293635bfb49c525409f3bcf671dce (diff) | |
download | stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.tar.gz stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.tar.bz2 stable-diffusion-webui-gfx803-fb3cefb348600497d964c4fd3f99138d7dde0ec5.zip |
Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 53 |
1 files changed, 28 insertions, 25 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 470659df..b6c06d49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,7 +14,7 @@ import torch from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
@@ -120,6 +120,17 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
@@ -164,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
- assert hypernetwork_name, 'embedding not selected'
+ assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -212,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps:
return hypernetwork, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text, cond) in pbar:
+ for i, entry in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except Exception:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- cond = cond.to(devices.device)
- x = x.to(devices.device)
+ cond = entry.cond.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -256,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -271,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, )
processed = processing.process_images(p)
- image = processed.images[0]
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -288,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, <p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
|