aboutsummaryrefslogtreecommitdiffstats
path: root/modules/safety.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-04 15:57:14 +0000
committerGitHub <noreply@github.com>2023-01-04 15:57:14 +0000
commit32547f2721c92794779e6ff9fb325243d5857cae (patch)
treed4d5f1a9705e59eef5029cc1be3bff57fbd389c2 /modules/safety.py
parentfe6e2362e8fa5d739de6997ab155a26686d20a49 (diff)
parent3dae545a03f5102ba5d9c3f27bb6241824c5a916 (diff)
downloadstable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.tar.gz
stable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.tar.bz2
stable-diffusion-webui-gfx803-32547f2721c92794779e6ff9fb325243d5857cae.zip
Merge branch 'master' into xygrid_infotext_improvements
Diffstat (limited to 'modules/safety.py')
-rw-r--r--modules/safety.py42
1 files changed, 0 insertions, 42 deletions
diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
- global safety_feature_extractor, safety_checker
-
- if safety_feature_extractor is None:
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
- return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
- x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
- x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- return x