aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
authorGRMrGecko <grmrgecko@gmail.com>2022-09-13 00:15:35 +0000
committerGRMrGecko <grmrgecko@gmail.com>2022-09-13 00:15:35 +0000
commitfc18e2d48325afd902920ec519946c8b2a8019d0 (patch)
tree5c20e1483eb022062fd752d559bdee4e83de1e66 /modules/processing.py
parentfa8be8acd62894bfc96da985326fda3208266468 (diff)
downloadstable-diffusion-webui-gfx803-fc18e2d48325afd902920ec519946c8b2a8019d0.tar.gz
stable-diffusion-webui-gfx803-fc18e2d48325afd902920ec519946c8b2a8019d0.tar.bz2
stable-diffusion-webui-gfx803-fc18e2d48325afd902920ec519946c8b2a8019d0.zip
Adds NSFW content filter option
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 92bf66f2..e777a965 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -19,6 +19,14 @@ import modules.face_restoration
import modules.images as images
import modules.styles
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
+# load safety model
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -146,6 +154,28 @@ def fix_seed(p):
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
+def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+# check and replace nsfw content
+def check_safety(x_image):
+ global safety_feature_extractor, safety_checker
+ if safety_feature_extractor is None:
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+ return x_checked_image, has_nsfw_concept
+
+
def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -248,6 +278,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ if opts.filter_nsfw:
+ x_samples_ddim_numpy = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+ x_samples_ddim = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)