path: root/modules/safety.py
author    AUTOMATIC <16777216c@gmail.com>    2022-09-13 05:34:41 +0000
committer AUTOMATIC <16777216c@gmail.com>    2022-09-13 05:34:41 +0000
commit    b5a8b99d3fd3a233da3835f193afe35818dc24b1 (patch)
tree      39c7763cd6bc70f97c24a3c160c3ae5a56038a64 /modules/safety.py
parent    b03bc4e79ab1a896ffb7295b4c3a0c30868a2c4a (diff)
put safety checker into a separate file because it's already crowded in processing
Diffstat (limited to 'modules/safety.py')
-rw-r--r--  modules/safety.py  42
1 file changed, 42 insertions, 0 deletions
diff --git a/modules/safety.py b/modules/safety.py
new file mode 100644
index 00000000..cff4b278
--- /dev/null
+++ b/modules/safety.py
@@ -0,0 +1,42 @@
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+from PIL import Image
+
+import modules.shared as shared
+
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
+def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+# check and replace nsfw content
+def check_safety(x_image):
+    global safety_feature_extractor, safety_checker
+
+    if safety_feature_extractor is None:
+        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+
+    return x_checked_image, has_nsfw_concept
+
+
+def censor_batch(x):
+    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
+    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+    return x
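
For context, a minimal sketch of how censor_batch could be hooked into the sampling code this module was split out of. The surrounding names (samples_ddim, the filter_nsfw flag, and the helper that wraps them) are assumptions for illustration, not part of this commit; only modules.safety and its censor_batch function come from the diff above.

import torch

import modules.safety as safety

def maybe_censor(samples_ddim: torch.Tensor, filter_nsfw: bool) -> torch.Tensor:
    # samples_ddim: decoded images as a float tensor in [0, 1],
    # shaped (batch, channels, height, width); the name is hypothetical.
    if filter_nsfw:
        # censor_batch moves the batch to the CPU, runs the lazily loaded
        # safety checker from check_safety, and returns a tensor of the same
        # shape with any flagged images replaced by black images.
        samples_ddim = safety.censor_batch(samples_ddim)
    return samples_ddim

Because check_safety only downloads and instantiates the CompVis/stable-diffusion-safety-checker model on the first call, keeping the filter disabled adds no model-loading cost.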