Mirror of https://github.com/huggingface/diffusers.git
Synced 2026-01-29 07:22:12 +03:00
Move safety detection to model call in Flax safety checker (#1023)

* Move safety detection to model call in Flax safety checker

* Update src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
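
In short: the pipeline no longer pulls raw cosine-distance scores out of the safety checker and post-processes them on the host; the model call itself returns a per-image boolean verdict, and the pipeline only blacks out flagged images. A minimal sketch of the new flow (assuming a loaded pipeline `pipe`, extracted `features`, and safety-checker params `safety_params`; variable names are illustrative, not the exact pipeline code):

import numpy as np

# the model call now returns one boolean per image
has_nsfw_concepts = pipe.safety_checker(features, safety_params)

# the pipeline's only remaining job: black out flagged images
for idx, flagged in enumerate(has_nsfw_concepts):
    if flagged:
        images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)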
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Dict, List, Optional, Union
 
@@ -97,9 +98,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         )
         return text_input.input_ids
 
-    def _get_safety_scores(self, features, params):
-        special_cos_dist, cos_dist = self.safety_checker(features, params)
-        return (special_cos_dist, cos_dist)
+    def _get_has_nsfw_concepts(self, features, params):
+        has_nsfw_concepts = self.safety_checker(features, params)
+        return has_nsfw_concepts
 
     def _run_safety_checker(self, images, safety_model_params, jit=False):
         # safety_model_params should already be replicated when jit is True
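
The helper rename above is the heart of the change: `_get_safety_scores` returned two float score matrices that still needed host-side processing, while `_get_has_nsfw_concepts` returns the final verdict straight from the model call, so the whole check can run inside `jax.pmap`. A sketch of the before/after return values (a batch size of 4 is an assumption):

# before: two float matrices, one row per image
# special_cos_dist.shape == (4, n_special_concepts)
# cos_dist.shape == (4, n_concepts)
special_cos_dist, cos_dist = pipe._get_safety_scores(features, params)

# after: a single boolean array, one entry per image
# has_nsfw_concepts.shape == (4,)
has_nsfw_concepts = pipe._get_has_nsfw_concepts(features, params)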
@@ -108,20 +109,28 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
 
         if jit:
             features = shard(features)
-            special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
-            special_cos_dist = unshard(special_cos_dist)
-            cos_dist = unshard(cos_dist)
+            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+            has_nsfw_concepts = unshard(has_nsfw_concepts)
             safety_model_params = unreplicate(safety_model_params)
         else:
-            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
+            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
 
-        images, has_nsfw = self.safety_checker.filtered_with_scores(
-            special_cos_dist,
-            cos_dist,
-            images,
-            safety_model_params,
-        )
-        return images, has_nsfw
+        images_was_copied = False
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if not images_was_copied:
+                    images_was_copied = True
+                    images = images.copy()
+
+                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
+
+        if any(has_nsfw_concepts):
+            warnings.warn(
+                "Potential NSFW content was detected in one or more images. A black image will be returned"
+                " instead. Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts
 
     def _generate(
         self,
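
The blacking-out loop that used to live in the checker's `filtered_with_scores` now lives in the pipeline, unchanged in spirit: copy the batch lazily (once, and only if something was flagged), then zero the flagged frames. A self-contained sketch of the same pattern with dummy data:

import warnings

import numpy as np

images = np.full((3, 8, 8, 3), 255, dtype=np.uint8)  # dummy batch of 3 white images
has_nsfw_concepts = [False, True, False]

images_was_copied = False
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
    if has_nsfw_concept:
        if not images_was_copied:
            images_was_copied = True  # copy once, so the all-safe case pays nothing
            images = images.copy()

        images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

if any(has_nsfw_concepts):
    warnings.warn("Potential NSFW content was detected; black image(s) returned instead.")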
@@ -310,8 +319,8 @@ def _p_generate(
 
 
 @partial(jax.pmap, static_broadcasted_argnums=(0,))
-def _p_get_safety_scores(pipe, features, params):
-    return pipe._get_safety_scores(features, params)
+def _p_get_has_nsfw_concepts(pipe, features, params):
+    return pipe._get_has_nsfw_concepts(features, params)
 
 
 def unshard(x: jnp.ndarray):
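
`static_broadcasted_argnums=(0,)` tells `jax.pmap` to treat the first argument, the pipeline object itself, as a static Python value instead of an array to map; only `features` and `params` are split across devices. A minimal, self-contained sketch of the same pattern (the `Scaler` class is a hypothetical stand-in for the pipeline):

from functools import partial

import jax
import jax.numpy as jnp

class Scaler:
    factor = 2.0

@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_scale(obj, x):
    # `obj` is broadcast statically; `x` is sharded, one slice per device
    return x * obj.factor

n = jax.local_device_count()
x = jnp.ones((n, 4))         # leading axis must equal the device count
out = _p_scale(Scaler(), x)  # shape (n, 4)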
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -1,8 +1,5 @@
-import warnings
 from typing import Optional, Tuple
 
-import numpy as np
-
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -39,56 +36,22 @@ class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
 
         special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
         cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
-        return special_cos_dist, cos_dist
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
-        batch_size = special_cos_dist.shape[0]
-        special_cos_dist = np.asarray(special_cos_dist)
-        cos_dist = np.asarray(cos_dist)
-
-        result = []
-        for i in range(batch_size):
-            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
-
-            # increase this value to create a stronger `nfsw` filter
-            # at the cost of increasing the possibility of filtering benign image inputs
-            adjustment = 0.0
-
-            for concept_idx in range(len(special_cos_dist[0])):
-                concept_cos = special_cos_dist[i][concept_idx]
-                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
-                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["special_scores"][concept_idx] > 0:
-                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
-                    adjustment = 0.01
-
-            for concept_idx in range(len(cos_dist[0])):
-                concept_cos = cos_dist[i][concept_idx]
-                concept_threshold = self.concept_embeds_weights[concept_idx].item()
-                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["concept_scores"][concept_idx] > 0:
-                    result_img["bad_concepts"].append(concept_idx)
-
-            result.append(result_img)
-
-        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
-
-        images_was_copied = False
-        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
-            if has_nsfw_concept:
-                if not images_was_copied:
-                    images_was_copied = True
-                    images = images.copy()
-
-                images[idx] = np.zeros(images[idx].shape)  # black image
-
-        if any(has_nsfw_concepts):
-            warnings.warn(
-                "Potential NSFW content was detected in one or more images. A black image will be returned"
-                " instead. Try again with a different prompt and/or seed."
-            )
-
-        return images, has_nsfw_concepts
+
+        # increase this value to create a stronger `nfsw` filter
+        # at the cost of increasing the possibility of filtering benign image inputs
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+        special_scores = jnp.round(special_scores, 3)
+        is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+        # Use a lower threshold if an image has any special care concept
+        special_adjustment = is_special_care * 0.01
+
+        concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+        concept_scores = jnp.round(concept_scores, 3)
+        has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+        return has_nsfw_concepts
 
 
 class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
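
The vectorized rewrite above replaces the per-image Python loops with one broadcasted computation: the distance matrices have shape `(batch, n_concepts)`, the threshold weights have shape `(n_concepts,)`, and `weights[None, :]` lifts them to `(1, n_concepts)` so the subtraction broadcasts over the batch; `jnp.any(..., axis=1)` then collapses each row to a single flag. Likewise `is_special_care * 0.01` is a `(batch, 1)` column that broadcasts a per-image threshold adjustment over all concepts. A toy example (shapes and values are illustrative only):

import jax.numpy as jnp

cos_dist = jnp.array([[0.10, 0.30, 0.05],
                      [0.02, 0.01, 0.00]])              # (2 images, 3 concepts)
concept_embeds_weights = jnp.array([0.20, 0.25, 0.30])  # (3,)

concept_scores = cos_dist - concept_embeds_weights[None, :]  # (2, 3) via broadcasting
has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)      # (2,)
print(has_nsfw_concepts)  # [ True False]: only image 0 exceeds a threshold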
@@ -133,15 +96,3 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
             jnp.array(clip_input, dtype=jnp.float32),
             rngs={},
         )
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
-        def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
-            return module.filtered_with_scores(special_cos_dist, cos_dist, images)
-
-        return self.module.apply(
-            {"params": params or self.params},
-            special_cos_dist,
-            cos_dist,
-            images,
-            method=_filtered_with_scores,
-        )
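
The deleted wrapper relied on Flax's `method=` argument, which lets `Module.apply` invoke a method other than `__call__` on a module; once all detection moved into `__call__`, that indirection had nothing left to do. A minimal sketch of the `method=` mechanism with a toy module (illustrative, not the safety checker):

import jax
import jax.numpy as jnp
from flax import linen as nn

class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

    def double(self, x):
        # a parameter-free method we can still route through apply()
        return x * 2

toy = Toy()
x = jnp.ones((1, 3))
variables = toy.init(jax.random.PRNGKey(0), x)

y = toy.apply(variables, x)                     # runs __call__
z = toy.apply(variables, x, method=Toy.double)  # runs double instead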