diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index f23c2d36dd..9fb2be02ac 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -341,21 +341,20 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
 
     def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
         if self.safety_checker is not None:
+            images = image.copy()
             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-            flagged_images = None
+            flagged_images = np.zeros((2, *image.shape[1:]))
             if any(has_nsfw_concept):
                 logger.warning(
-                    "Potential NSFW content was detected in one or more images. A black image will be returned"
-                    " instead."
-                    f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
+                    "Potential NSFW content was detected in one or more images. A black image will be returned instead. "
+                    f"{'You may look at these images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
                 )
-                flagged_images = np.zeros((2, *image.shape[1:]))
                 for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
                     if has_nsfw_concept:
-                        flagged_images[idx] = image[idx]
+                        flagged_images[idx] = images[idx]
                         image[idx] = np.zeros(image[idx].shape)  # black image
             else:
                 has_nsfw_concept = None
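
For readability, below is a minimal sketch of run_safety_checker as it reads after this patch. It is reconstructed from the hunk above only: the surrounding class, the numpy/logger imports, the full warning message, and the truncated tail of the method are assumptions, not part of the diff.

import numpy as np

def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
    if self.safety_checker is not None:
        # Copy the batch *before* the safety checker runs; the checker may
        # modify flagged images, so these copies are what can later be exposed
        # (e.g. via `unsafe_images`) when safety guidance is enabled.
        images = image.copy()
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
        # Pre-allocate the buffer for flagged originals instead of starting from None.
        flagged_images = np.zeros((2, *image.shape[1:]))
        if any(has_nsfw_concept):
            logger.warning("Potential NSFW content was detected ...")  # message abbreviated here
            for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
                if has_nsfw_concept:
                    flagged_images[idx] = images[idx]        # keep the pre-checker original
                    image[idx] = np.zeros(image[idx].shape)  # black image in the returned batch
        else:
            has_nsfw_concept = None
        ...  # remainder of the method lies outside this hunk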