From b5a8b99d3fd3a233da3835f193afe35818dc24b1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 13 Sep 2022 08:34:41 +0300
Subject: put safety checker into a separate file because it's already crowded
 in processing

---
 modules/safety.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 modules/safety.py

diff --git a/modules/safety.py b/modules/safety.py
new file mode 100644
index 00000000..cff4b278
--- /dev/null
+++ b/modules/safety.py
@@ -0,0 +1,42 @@
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+from PIL import Image
+
+import modules.shared as shared
+
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+# check and replace nsfw content
+def check_safety(x_image):
+    global safety_feature_extractor, safety_checker
+
+    if safety_feature_extractor is None:
+        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+
+    return x_checked_image, has_nsfw_concept
+
+
+def censor_batch(x):
+    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
+    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+    return x
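
For context, a minimal usage sketch (not part of the commit) of how the new module might be called from generation code. The helper name to_pil_images, the assumed tensor shape (N, C, H, W with values in [0, 1]), and the call site are illustrative assumptions; the commit itself only adds modules/safety.py, and the real caller lives elsewhere in the webui codebase.

import torch

from modules import safety  # requires the webui environment (modules.shared) to be importable


def to_pil_images(samples: torch.Tensor):
    # Hypothetical helper: run a decoded sample batch shaped (N, C, H, W),
    # values in [0, 1], through the safety checker, then convert to PIL images.
    samples = safety.censor_batch(samples)      # flagged images come back replaced by the checker
    arr = samples.permute(0, 2, 3, 1).numpy()   # censor_batch returns a CPU tensor; NCHW -> NHWC for PIL
    return safety.numpy_to_pil(arr)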