aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/images.py2
-rw-r--r--modules/processing.py5
-rw-r--r--modules/safety.py42
-rw-r--r--modules/shared.py1
4 files changed, 48 insertions, 2 deletions
diff --git a/modules/images.py b/modules/images.py
index ddd310a2..fc9a0113 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -299,7 +299,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if existing_info is not None:
for k, v in existing_info.items():
- pnginfo.add_text(k, v)
+ pnginfo.add_text(k, str(v))
pnginfo.add_text(pnginfo_section_name, info)
else:
diff --git a/modules/processing.py b/modules/processing.py
index 92bf66f2..65ae4846 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -24,7 +24,6 @@ opt_C = 4
opt_f = 8
-
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
@@ -248,6 +247,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ if opts.filter_nsfw:
+ import modules.safety as safety
+ x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
diff --git a/modules/safety.py b/modules/safety.py
new file mode 100644
index 00000000..cff4b278
--- /dev/null
+++ b/modules/safety.py
@@ -0,0 +1,42 @@
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+from PIL import Image
+
+import modules.shared as shared
+
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
+def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+# check and replace nsfw content
+def check_safety(x_image):
+ global safety_feature_extractor, safety_checker
+
+ if safety_feature_extractor is None:
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+
+ return x_checked_image, has_nsfw_concept
+
+
+def censor_batch(x):
+ x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+ x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+ return x
diff --git a/modules/shared.py b/modules/shared.py
index 891d7fb2..37c333f3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -111,6 +111,7 @@ class Options:
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for individual samples'),
+ "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"grid_save": OptionInfo(True, "Save image grids"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"grid_format": OptionInfo('png', 'File format for grids'),