mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Move safety detection to model call in Flax safety checker (#1023)

* Move safety detection to model call in Flax safety checker

* Update src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
Jonatan Kłosko authored 2022-10-30 20:07:55 +01:00, committed by GitHub
parent 95414bd6bf · commit 8e4fd686e0
2 changed files with 37 additions and 77 deletions
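The substance of the change, for readers skimming the diff: the Flax safety-checker module now thresholds the CLIP cosine distances itself with vectorized jnp ops and returns one boolean per image, instead of handing raw distances back to the pipeline for a per-image Python loop. Below is a minimal, self-contained sketch of that vectorized scoring; the function name, argument names, and shapes are illustrative only (the real module uses its learned special_care_embeds_weights / concept_embeds_weights as thresholds).

import jax.numpy as jnp


def has_nsfw(special_cos_dist, cos_dist, special_thresholds, concept_thresholds):
    # special_cos_dist: (batch, n_special), cos_dist: (batch, n_concepts)
    adjustment = 0.0  # raise to filter more aggressively, at the risk of false positives

    special_scores = jnp.round(special_cos_dist - special_thresholds[None, :] + adjustment, 3)
    # Images that trip any "special care" concept get a slightly lower concept threshold.
    special_adjustment = jnp.any(special_scores > 0, axis=1, keepdims=True) * 0.01

    concept_scores = jnp.round(cos_dist - concept_thresholds[None, :] + special_adjustment, 3)
    return jnp.any(concept_scores > 0, axis=1)  # (batch,) bool


# Toy invocation with made-up shapes (3 special-care concepts, 17 concepts):
flags = has_nsfw(jnp.zeros((2, 3)), jnp.zeros((2, 17)), jnp.full((3,), 0.5), jnp.full((17,), 0.5))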

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Dict, List, Optional, Union
@@ -97,9 +98,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         )
         return text_input.input_ids

-    def _get_safety_scores(self, features, params):
-        special_cos_dist, cos_dist = self.safety_checker(features, params)
-        return (special_cos_dist, cos_dist)
+    def _get_has_nsfw_concepts(self, features, params):
+        has_nsfw_concepts = self.safety_checker(features, params)
+        return has_nsfw_concepts

     def _run_safety_checker(self, images, safety_model_params, jit=False):
         # safety_model_params should already be replicated when jit is True
@@ -108,20 +109,28 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         if jit:
             features = shard(features)
-            special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
-            special_cos_dist = unshard(special_cos_dist)
-            cos_dist = unshard(cos_dist)
+            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+            has_nsfw_concepts = unshard(has_nsfw_concepts)
             safety_model_params = unreplicate(safety_model_params)
         else:
-            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
-
-        images, has_nsfw = self.safety_checker.filtered_with_scores(
-            special_cos_dist,
-            cos_dist,
-            images,
-            safety_model_params,
-        )
-
-        return images, has_nsfw
+            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+        images_was_copied = False
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if not images_was_copied:
+                    images_was_copied = True
+                    images = images.copy()
+
+                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
+
+            if any(has_nsfw_concepts):
+                warnings.warn(
+                    "Potential NSFW content was detected in one or more images. A black image will be returned"
+                    " instead. Try again with a different prompt and/or seed."
+                )
+
+        return images, has_nsfw_concepts

     def _generate(
         self,
@@ -310,8 +319,8 @@ def _p_generate(
 @partial(jax.pmap, static_broadcasted_argnums=(0,))
-def _p_get_safety_scores(pipe, features, params):
-    return pipe._get_safety_scores(features, params)
+def _p_get_has_nsfw_concepts(pipe, features, params):
+    return pipe._get_has_nsfw_concepts(features, params)


 def unshard(x: jnp.ndarray):
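For context, here is a self-contained sketch of the pmap + shard/unshard pattern the jit path above relies on. Only the jax.pmap(..., static_broadcasted_argnums=(0,)) usage mirrors the pipeline; ToyChecker and the shard/unshard helpers are simplified stand-ins invented for illustration, not the ones the pipeline uses.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


class ToyChecker:
    # Plays the role of `pipe`: broadcast statically (unmapped) to every device by pmap.
    def __init__(self, threshold):
        self.threshold = threshold

    def _get_flags(self, features):
        # One boolean per example, computed on-device.
        return jnp.any(features > self.threshold, axis=-1)


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_flags(checker, features):
    return checker._get_flags(features)


def shard(x):
    # (batch, ...) -> (devices, batch // devices, ...)
    n = jax.local_device_count()
    return x.reshape((n, x.shape[0] // n) + x.shape[1:])


def unshard(x):
    # (devices, per_device_batch, ...) -> (batch, ...)
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


checker = ToyChecker(threshold=0.9)
features = np.random.rand(jax.local_device_count() * 2, 16).astype(np.float32)
flags = unshard(_p_get_flags(checker, shard(features)))  # shape (batch,), dtype bool
print(flags)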

src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py

@@ -1,8 +1,5 @@
-import warnings
 from typing import Optional, Tuple

-import numpy as np
-
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -39,56 +36,22 @@ class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
         special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
         cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
-        return special_cos_dist, cos_dist
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
-        batch_size = special_cos_dist.shape[0]
-        special_cos_dist = np.asarray(special_cos_dist)
-        cos_dist = np.asarray(cos_dist)
-
-        result = []
-        for i in range(batch_size):
-            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
-
-            # increase this value to create a stronger `nfsw` filter
-            # at the cost of increasing the possibility of filtering benign image inputs
-            adjustment = 0.0
-
-            for concept_idx in range(len(special_cos_dist[0])):
-                concept_cos = special_cos_dist[i][concept_idx]
-                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
-                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["special_scores"][concept_idx] > 0:
-                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
-                    adjustment = 0.01
-
-            for concept_idx in range(len(cos_dist[0])):
-                concept_cos = cos_dist[i][concept_idx]
-                concept_threshold = self.concept_embeds_weights[concept_idx].item()
-                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["concept_scores"][concept_idx] > 0:
-                    result_img["bad_concepts"].append(concept_idx)
-
-            result.append(result_img)
-
-        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
-
-        images_was_copied = False
-        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
-            if has_nsfw_concept:
-                if not images_was_copied:
-                    images_was_copied = True
-                    images = images.copy()
-
-                images[idx] = np.zeros(images[idx].shape)  # black image
-
-            if any(has_nsfw_concepts):
-                warnings.warn(
-                    "Potential NSFW content was detected in one or more images. A black image will be returned"
-                    " instead. Try again with a different prompt and/or seed."
-                )
-
-        return images, has_nsfw_concepts
+
+        # increase this value to create a stronger `nfsw` filter
+        # at the cost of increasing the possibility of filtering benign image inputs
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+        special_scores = jnp.round(special_scores, 3)
+        is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+        # Use a lower threshold if an image has any special care concept
+        special_adjustment = is_special_care * 0.01
+
+        concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+        concept_scores = jnp.round(concept_scores, 3)
+        has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+        return has_nsfw_concepts


 class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
@@ -133,15 +96,3 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
             jnp.array(clip_input, dtype=jnp.float32),
             rngs={},
         )
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
-        def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
-            return module.filtered_with_scores(special_cos_dist, cos_dist, images)
-
-        return self.module.apply(
-            {"params": params or self.params},
-            special_cos_dist,
-            cos_dist,
-            images,
-            method=_filtered_with_scores,
-        )
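The removed filtered_with_scores wrapper relied on Flax's Module.apply(..., method=...) mechanism to dispatch to a method other than __call__ on the same bound parameters. A toy, self-contained sketch of that pattern follows; the Scorer module, its flag method, and all shapes are invented for illustration and are not part of the diffusers API.

import jax
import jax.numpy as jnp
from flax import linen as nn


class Scorer(nn.Module):
    # Toy module: __call__ produces scores, `flag` thresholds them.
    features: int = 4

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

    def flag(self, scores, threshold):
        return jnp.any(scores > threshold, axis=-1)


module = Scorer()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))

scores = module.apply(variables, jnp.ones((2, 8)))
# Dispatch to a method other than __call__ against the same parameters,
# as the removed wrapper did with its local `_filtered_with_scores` function:
flags = module.apply(variables, scores, 0.5, method=Scorer.flag)
print(flags)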