Mirror of https://github.com/huggingface/diffusers.git

[Bug fix] Fix img2img processor with safety checker (#3127)

Fix img2img processor with safety checker
Authored by Patrick von Platen on 2023-04-17 11:53:04 +02:00, committed by Daniel Gu
parent c98e41dffe
commit 653b3c1a1a
2 changed files with 18 additions and 1 deletion


@@ -85,7 +85,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
         for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
             if has_nsfw_concept:
-                images[idx] = np.zeros(images[idx].shape)  # black image
+                if torch.is_tensor(images) or torch.is_tensor(images[0]):
+                    images[idx] = torch.zeros_like(images[idx])  # black image
+                else:
+                    images[idx] = np.zeros(images[idx].shape)  # black image
 
         if any(has_nsfw_concepts):
             logger.warning(
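The hunk above replaces the unconditional numpy blackout with a type check, so the safety checker also blacks out flagged images when the pipeline hands it torch tensors rather than numpy arrays. A minimal standalone sketch of the fixed branching (not part of the commit; the helper name blackout_flagged and the dummy shapes are made up for illustration):

# Sketch of the fixed logic; helper name and inputs are illustrative.
import numpy as np
import torch

def blackout_flagged(images, has_nsfw_concepts):
    # Zero out flagged images while preserving the input's array type:
    # torch.zeros_like keeps dtype/device for tensors, np.zeros covers arrays.
    for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
        if has_nsfw_concept:
            if torch.is_tensor(images) or torch.is_tensor(images[0]):
                images[idx] = torch.zeros_like(images[idx])  # black image
            else:
                images[idx] = np.zeros(images[idx].shape)  # black image
    return images

# Works for a batched torch tensor (the img2img case this PR fixes) ...
imgs_pt = blackout_flagged(torch.ones(2, 3, 64, 64), [True, False])
assert imgs_pt[0].abs().sum() == 0
# ... and still works for a list of numpy arrays (the previous behavior).
imgs_np = blackout_flagged([np.ones((64, 64, 3)) for _ in range(2)], [True, False])
assert np.abs(imgs_np[0]).sum() == 0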


@@ -453,6 +453,20 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
 
+    def test_img2img_safety_checker_works(self):
+        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 20
+        # make sure the safety checker is activated
+        inputs["prompt"] = "naked, sex, porn"
+        out = sd_pipe(**inputs)
+
+        assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
+        assert np.abs(out.images[0]).sum() < 1e-5  # should be all zeros
 
 @nightly
 @require_torch_gpu
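The new slow test exercises the fix end to end through the img2img pipeline. A hedged usage sketch of the same check from user code (not part of the diff; the grey init image, prompt, and strength value are placeholders, and a real run downloads the v1-5 weights):

# Consuming the safety flag the test asserts on; inputs are placeholders.
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

init_image = Image.new("RGB", (512, 512), (127, 127, 127))
out = pipe(prompt="a fantasy landscape", image=init_image, strength=0.75)

if out.nsfw_content_detected[0]:
    # with this fix, a flagged image comes back fully blacked out
    assert np.abs(np.asarray(out.images[0])).sum() < 1e-5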