
fix: flagged_images implementation (#1947)

Flagged images were being set to the blank (black) image instead of the original image that contained the NSFW concept, which made optional viewing of the flagged content via `unsafe_images` impossible.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Justin Merrell
Date: 2023-02-03 12:02:56 -05:00
Committed by: GitHub
Parent: 2f9a70aa85
Commit: 948022e1e8
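
For orientation, the sketch below is not part of the commit; it illustrates the root cause under the assumption that the safety checker returns flagged images already blacked out. Copying from the checker's return value can then only ever capture a black frame, which is why the fix snapshots the batch (`images = image.copy()`) before the checker runs.

import numpy as np

def fake_safety_checker(images):
    # Stand-in for the real safety checker (assumed behavior, see lead-in):
    # it flags index 1 and returns that image blacked out.
    has_nsfw_concept = [False, True]
    checked = images.copy()
    checked[1] = 0.0
    return checked, has_nsfw_concept

batch = np.random.rand(2, 64, 64, 3)

images = batch.copy()  # the fix: keep the originals before running the checker
image, has_nsfw_concept = fake_safety_checker(batch)

flagged_images = np.zeros((2, *image.shape[1:]))
for idx, flagged in enumerate(has_nsfw_concept):
    if flagged:
        flagged_images[idx] = images[idx]  # original content kept for optional viewing
        image[idx] = np.zeros(image[idx].shape)  # returned image stays black

assert flagged_images[1].any()  # the pre-fix copy from `image[idx]` would be all zeros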


@@ -341,21 +341,20 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
         if self.safety_checker is not None:
+            images = image.copy()
             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-            flagged_images = None
+            flagged_images = np.zeros((2, *image.shape[1:]))
             if any(has_nsfw_concept):
                 logger.warning(
-                    "Potential NSFW content was detected in one or more images. A black image will be returned"
-                    " instead."
-                    f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
+                    "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                    f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
                 )
-                flagged_images = np.zeros((2, *image.shape[1:]))
                 for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
                     if has_nsfw_concept:
-                        flagged_images[idx] = image[idx]
+                        flagged_images[idx] = images[idx]
                         image[idx] = np.zeros(image[idx].shape)  # black image
         else:
             has_nsfw_concept = None
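
A minimal usage sketch (not part of the commit) of how the corrected behavior surfaces to callers. The checkpoint id is only illustrative, and the `nsfw_content_detected` field plus PIL-formatted `unsafe_images` are assumptions; `unsafe_images` itself is the output field named in the warning above.

import torch
from diffusers import StableDiffusionPipelineSafe

# Hedged sketch: checkpoint id, `nsfw_content_detected`, and PIL output are assumptions.
pipe = StableDiffusionPipelineSafe.from_pretrained(
    "AIML-TUDA/stable-diffusion-safe",  # illustrative checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

out = pipe("a portrait photo", num_images_per_prompt=2)

for idx, flagged in enumerate(out.nsfw_content_detected or []):
    if flagged:
        # out.images[idx] is the blacked-out frame; after this fix,
        # out.unsafe_images[idx] holds the original image for optional viewing.
        out.unsafe_images[idx].save(f"flagged_{idx}.png")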