aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetworks
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-14 17:31:49 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-14 17:31:49 +0000
commitc344ba3b325459abbf9b0df2c1b18f7bf99805b2 (patch)
treea55413118729c0ccabb51cd4b94bbe1ada508351 /modules/hypernetworks
parentbb295f54785ac36dc6aa6f7103a3431464440fc3 (diff)
downloadstable-diffusion-webui-gfx803-c344ba3b325459abbf9b0df2c1b18f7bf99805b2.tar.gz
stable-diffusion-webui-gfx803-c344ba3b325459abbf9b0df2c1b18f7bf99805b2.tar.bz2
stable-diffusion-webui-gfx803-c344ba3b325459abbf9b0df2c1b18f7bf99805b2.zip
add option to read generation params for learning previews from txt2img
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--modules/hypernetworks/hypernetwork.py21
1 files changed, 16 insertions, 5 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f1248bb7..e5cb1817 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -180,7 +180,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -265,20 +265,31 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entry.cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None