From b6763fb8847df5a5678f37137e7a702569e5c925 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 5 Sep 2022 23:08:06 +0300 Subject: added random artist button added a setting for padding when doing inpaint at original resolution --- modules/artists.py | 25 +++++++++++++++++++++++++ modules/processing.py | 2 +- modules/shared.py | 7 +++++++ modules/ui.py | 19 +++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 modules/artists.py (limited to 'modules') diff --git a/modules/artists.py b/modules/artists.py new file mode 100644 index 00000000..3612758b --- /dev/null +++ b/modules/artists.py @@ -0,0 +1,25 @@ +import os.path +import csv +from collections import namedtuple + +Artist = namedtuple("Artist", ['name', 'weight', 'category']) + + +class ArtistsDatabase: + def __init__(self, filename): + self.cats = set() + self.artists = [] + + if not os.path.exists(filename): + return + + with open(filename, "r", newline='', encoding="utf8") as file: + reader = csv.DictReader(file) + + for row in reader: + artist = Artist(row["artist"], float(row["score"]), row["category"]) + self.artists.append(artist) + self.cats.add(artist.category) + + def categories(self): + return sorted(self.cats) diff --git a/modules/processing.py b/modules/processing.py index b744aa87..c0c1adb7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -324,7 +324,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpaint_full_res: self.mask_for_overlay = self.image_mask mask = self.image_mask.convert('L') - crop_region = get_crop_region(np.array(mask), 64) + crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding) x1, y1, x2, y2 = crop_region mask = mask.crop(crop_region) diff --git a/modules/shared.py b/modules/shared.py index 70946fea..4e36df37 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import os import gradio as gr import torch +import modules.artists from modules.paths import script_path, sd_path config_filename = "config.json" @@ -47,6 +48,8 @@ class State: state = State() +artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) + class Options: class OptionInfo: @@ -84,6 +87,8 @@ class Options: "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), + "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), + "upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}), } def __init__(self): @@ -122,3 +127,5 @@ if os.path.exists(config_filename): sd_upscalers = [] sd_model = None + + diff --git a/modules/ui.py b/modules/ui.py index ec583d14..aa5a61b7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -4,6 +4,7 @@ import io import json import mimetypes import os +import random import sys import time import traceback @@ -133,6 +134,13 @@ def wrap_gradio_call(func): return f +def roll_artist(prompt): + allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) + artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) + + return prompt + ", " + artist.name if prompt != '' else artist.name + + def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -146,6 +154,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Row(): prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1) negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1, visible=False) + roll = gr.Button('Roll', elem_id="txt2img_roll", visible=len(shared.artist_db.artists)>0) submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary') with gr.Row().style(equal_height=False): @@ -233,6 +242,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ] ) + roll.click( + fn=roll_artist, + inputs=[ + prompt, + ], + outputs=[ + prompt + ] + ) + with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Row(): prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1) -- cgit v1.2.3