diff options
-rw-r--r-- | .github/workflows/run_tests.yaml | 2 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 7 | ||||
-rw-r--r-- | modules/shared.py | 1 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 35 | ||||
-rw-r--r-- | modules/ui.py | 11 | ||||
-rw-r--r-- | test/advanced_features/__init__.py | 0 | ||||
-rw-r--r-- | test/advanced_features/extras_test.py | 29 | ||||
-rw-r--r-- | test/advanced_features/txt2img_test.py | 47 | ||||
-rw-r--r-- | test/basic_features/extras_test.py | 54 | ||||
-rw-r--r-- | test/basic_features/img2img_test.py | 7 | ||||
-rw-r--r-- | test/basic_features/txt2img_test.py | 11 | ||||
-rw-r--r-- | test/basic_features/utils_test.py | 17 | ||||
-rw-r--r-- | test/server_poll.py | 2 | ||||
-rw-r--r-- | webui.py | 3 |
14 files changed, 135 insertions, 91 deletions
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index ecb9012a..be7ffa23 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -18,7 +18,7 @@ jobs: cache-dependency-path: | **/requirements*txt - name: Run tests - run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys
import traceback
import inspect
+from collections import namedtuple
import torch
import tqdm
@@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): })
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
@@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert template_file, "Prompt template file is empty"
- assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0, "Max steps must be positive"
@@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
+from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
@@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row],
)
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
@@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
diff --git a/test/advanced_features/__init__.py b/test/advanced_features/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/test/advanced_features/__init__.py +++ /dev/null diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py deleted file mode 100644 index 8763f8ed..00000000 --- a/test/advanced_features/extras_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - - -class TestExtrasWorking(unittest.TestCase): - def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" - self.simple_extras = { - "resize_mode": 0, - "show_extras_results": True, - "gfpgan_visibility": 0, - "codeformer_visibility": 0, - "codeformer_weight": 0, - "upscaling_resize": 2, - "upscaling_resize_w": 128, - "upscaling_resize_h": 128, - "upscaling_crop": True, - "upscaler_1": "None", - "upscaler_2": "None", - "extras_upscaler_2_visibility": 0, - "image": "" - } - - -class TestExtrasCorrectness(unittest.TestCase): - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py deleted file mode 100644 index 36ed7b9a..00000000 --- a/test/advanced_features/txt2img_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import unittest -import requests - - -class TestTxt2ImgWorking(unittest.TestCase): - def setUp(self): - self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" - self.simple_txt2img = { - "enable_hr": False, - "denoising_strength": 0, - "firstphase_width": 0, - "firstphase_height": 0, - "prompt": "example prompt", - "styles": [], - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 3, - "cfg_scale": 7, - "width": 64, - "height": 64, - "restore_faces": False, - "tiling": False, - "negative_prompt": "", - "eta": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 1, - "sampler_index": "Euler a" - } - - def test_txt2img_with_restore_faces_performed(self): - self.simple_txt2img["restore_faces"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - -class TestTxt2ImgCorrectness(unittest.TestCase): - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py new file mode 100644 index 00000000..0170c511 --- /dev/null +++ b/test/basic_features/extras_test.py @@ -0,0 +1,54 @@ +import unittest +import requests +from gradio.processing_utils import encode_pil_to_base64 +from PIL import Image + +class TestExtrasWorking(unittest.TestCase): + def setUp(self): + self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image" + self.extras_single = { + "resize_mode": 0, + "show_extras_results": True, + "gfpgan_visibility": 0, + "codeformer_visibility": 0, + "codeformer_weight": 0, + "upscaling_resize": 2, + "upscaling_resize_w": 128, + "upscaling_resize_h": 128, + "upscaling_crop": True, + "upscaler_1": "None", + "upscaler_2": "None", + "extras_upscaler_2_visibility": 0, + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + } + + def test_simple_upscaling_performed(self): + self.extras_single["upscaler_1"] = "Lanczos" + self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200) + + +class TestPngInfoWorking(unittest.TestCase): + def setUp(self): + self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" + self.png_info = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + } + + def test_png_info_performed(self): + self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200) + + +class TestInterrogateWorking(unittest.TestCase): + def setUp(self): + self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" + self.interrogate = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), + "model": "clip" + } + + def test_interrogate_performed(self): + self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index bd520b13..08c5c903 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -16,7 +16,7 @@ class TestImg2ImgWorking(unittest.TestCase): "inpainting_fill": 0, "inpaint_full_res": False, "inpaint_full_res_padding": 0, - "inpainting_mask_invert": 0, + "inpainting_mask_invert": False, "prompt": "example prompt", "styles": [], "seed": -1, @@ -50,6 +50,11 @@ class TestImg2ImgWorking(unittest.TestCase): self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_inpainting_with_inverted_masked_performed(self): + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.simple_img2img["inpainting_mask_invert"] = True + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_img2img_sd_upscale_performed(self): self.simple_img2img["script_name"] = "sd upscale" self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 1c2674b2..5b27a7ec 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -41,6 +41,9 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["negative_prompt"] = "example negative prompt" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_complex_prompt_performed(self): + self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" + def test_txt2img_not_square_image_performed(self): self.simple_txt2img["height"] = 128 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) @@ -53,6 +56,10 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["tiling"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_restore_faces_performed(self): + self.simple_txt2img["restore_faces"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_vanilla_sampler_performed(self): self.simple_txt2img["sampler_index"] = "PLMS" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) @@ -63,6 +70,10 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["n_iter"] = 2 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_batch_performed(self): + self.simple_txt2img["batch_size"] = 2 + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + if __name__ == "__main__": unittest.main() diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 765470c9..94e00253 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -14,10 +14,25 @@ class UtilsTests(unittest.TestCase): self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" self.url_artists = "http://localhost:7860/sdapi/v1/artists" + self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) + def test_options_write(self): + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + + pre_value = response.json()["send_seed"] + + self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) + + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["send_seed"], not pre_value) + + requests.post(self.url_options, json={"send_seed": pre_value}) + def test_cmd_flags(self): self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) @@ -48,6 +63,8 @@ class UtilsTests(unittest.TestCase): def test_artists(self): self.assertEqual(requests.get(self.url_artists).status_code, 200) + def test_embeddings(self): + self.assertEqual(requests.get(self.url_artists).status_code, 200) if __name__ == "__main__": unittest.main() diff --git a/test/server_poll.py b/test/server_poll.py index d4df697b..42d56a4c 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -15,7 +15,7 @@ def run_tests(proc, test_dir): break if proc.poll() is None: if test_dir is None: - test_dir = "" + test_dir = "test" suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") result = unittest.TextTestRunner(verbosity=2).run(suite) return len(result.failures) + len(result.errors) @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
+import modules.textual_inversion.textual_inversion
import modules.ui
from modules import modelloader
@@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list()
+ modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
+
try:
modules.sd_models.load_model()
except Exception as e:
|