mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Bug fix] Fix img2img processor with safety checker (#3127)
Fix img2img processor with safety checker
This commit is contained in:
committed by
Daniel Gu
parent
c98e41dffe
commit
653b3c1a1a
@@ -85,7 +85,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
||||
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
||||
images[idx] = torch.zeros_like(images[idx]) # black image
|
||||
else:
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
logger.warning(
|
||||
|
||||
@@ -453,6 +453,20 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
def test_img2img_safety_checker_works(self):
|
||||
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 20
|
||||
# make sure the safety checker is activated
|
||||
inputs["prompt"] = "naked, sex, porn"
|
||||
out = sd_pipe(**inputs)
|
||||
|
||||
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
|
||||
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user