1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add Safe Stable Diffusion Pipeline (#1244)

* Add pipeline_stable_diffusion_safe.py to pipelines

* Fix repository consistency

Ran make fix-copies after adding new pipline

* Add Paper/Equation reference for parameters to doc string

* Ensure code style and quality

* Perform code refactoring

* Fix copies inherited from merge with huggingface/main

* Add docs

* Fix code style

* Fix errors in documentation

* Fix refactoring error

* remove debugging print statement

* added Safe Latent Diffusion tests

* Fix style

* Fix style

* Add pre-defined safety configurations

* Fix line-break

* fix some tests

* finish

* Change safety checker

* Add missing safety_checker.py file

* Remove unused imports

Co-authored-by: PatrickSchrML <patrick_schramowski@hotmail.de>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Manuel Brack
2022-11-22 11:51:30 +01:00
committed by GitHub
parent 182eb959e5
commit e50c25d808
12 changed files with 1450 additions and 1 deletions

View File

@@ -106,6 +106,8 @@
title: "Score SDE VE"
- local: api/pipelines/stable_diffusion
title: "Stable Diffusion"
- local: api/pipelines/stable_diffusion_safe
title: "Safe Stable Diffusion"
- local: api/pipelines/stochastic_karras_ve
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion

View File

@@ -58,7 +58,8 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |

View File

@@ -0,0 +1,90 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this.
The abstract of the paper is the following:
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
## Tips
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
### Run Safe Stable Diffusion
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
### Interacting with the Safety Concept
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.safety_concept
```
For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
### Using pre-defined safety configurations
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`.
### How to load and use different schedulers.
The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
... )
```
## StableDiffusionSafePipelineOutput
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
## StableDiffusionPipelineSafe
[[autodoc]] StableDiffusionPipelineSafe
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -48,6 +48,7 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |

View File

@@ -73,6 +73,7 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
VQDiffusionPipeline,
)
else:

View File

@@ -24,6 +24,7 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .vq_diffusion import VQDiffusionPipeline
if is_transformers_available() and is_onnx_available():

View File

@@ -0,0 +1,72 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_torch_available, is_transformers_available
@dataclass
class SafetyConfig(object):
WEAK = {
"sld_warmup_steps": 15,
"sld_guidance_scale": 20,
"sld_threshold": 0.0,
"sld_momentum_scale": 0.0,
"sld_mom_beta": 0.0,
}
MEDIUM = {
"sld_warmup_steps": 10,
"sld_guidance_scale": 1000,
"sld_threshold": 0.01,
"sld_momentum_scale": 0.3,
"sld_mom_beta": 0.4,
}
STRONG = {
"sld_warmup_steps": 7,
"sld_guidance_scale": 2000,
"sld_threshold": 0.025,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
MAX = {
"sld_warmup_steps": 0,
"sld_guidance_scale": 5000,
"sld_threshold": 1.0,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
"""
Output class for Safe Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work"
(nsfw) content, or `None` if no safety check was performed or no images were flagged.
applied_safety_concept (`str`)
The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
applied_safety_concept: Optional[str]
if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
from .safety_checker import SafeStableDiffusionSafetyChecker

View File

@@ -0,0 +1,721 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, is_accelerate_available, logging
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionPipelineSafe(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Safe Latent Diffusion.
The implementation is based on the [`StableDiffusionPipeline`]
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
],
safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
safety_concept: Optional[str] = (
"an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
" bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
" abuse, brutality, cruelty"
)
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self._safety_text_concept = safety_concept
@property
def safety_concept(self):
r"""
Getter method for the safety concept used with SLD
Returns:
`str`: The text describing the safety concept
"""
return self._safety_text_concept
@safety_concept.setter
def safety_concept(self, concept):
r"""
Setter method for the safety concept used with SLD
Args:
concept (`str`):
The text of the new safety concept
"""
self._safety_text_concept = concept
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
enable_safety_guidance,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# Encode the safety concept text
if enable_safety_guidance:
safety_concept_input = self.tokenizer(
[self._safety_text_concept],
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0]
# duplicate safety embeddings for each generation per prompt, using mps friendly method
seq_len = safety_embeddings.shape[1]
safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance + sld, we need to do three forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing three forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
flagged_images = None
if any(has_nsfw_concept):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead."
f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
)
flagged_images = np.zeros((2, *image.shape[1:]))
for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
if has_nsfw_concept:
flagged_images[idx] = image[idx]
image[idx] = np.zeros(image[idx].shape) # black image
else:
has_nsfw_concept = None
flagged_images = None
return image, has_nsfw_concept, flagged_images
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def perform_safety_guidance(
self,
enable_safety_guidance,
safety_momentum,
noise_guidance,
noise_pred_out,
i,
sld_guidance_scale,
sld_warmup_steps,
sld_threshold,
sld_momentum_scale,
sld_mom_beta,
):
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1]
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
)
# Equation 4
noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
return noise_guidance, safety_momentum
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
sld_guidance_scale: Optional[float] = 1000,
sld_warmup_steps: Optional[int] = 10,
sld_threshold: Optional[float] = 0.01,
sld_momentum_scale: Optional[float] = 0.3,
sld_mom_beta: Optional[float] = 0.4,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
sld_guidance_scale (`float`, *optional*, defaults to 1000):
Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
`sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be
disabled.
sld_warmup_steps (`int`, *optional*, defaults to 10):
Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
`sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_threshold (`float`, *optional*, defaults to 0.01):
Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
sld_momentum_scale (`float`, *optional*, defaults to 0.3):
Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_mom_beta (`float`, *optional*, defaults to 0.4):
Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
if not enable_safety_guidance:
warnings.warn("Safety checker disabled!")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
safety_momentum = None
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
# default classifier free guidance
noise_guidance = noise_pred_text - noise_pred_uncond
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
)
# Equation 4
noise_guidance_safety = torch.mul(
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
image, device, text_embeddings.dtype, enable_safety_guidance
)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if flagged_images is not None:
flagged_images = self.numpy_to_pil(flagged_images)
if not return_dict:
return (
image,
has_nsfw_concept,
self._safety_text_concept if enable_safety_guidance else None,
flagged_images,
)
return StableDiffusionSafePipelineOutput(
images=image,
nsfw_content_detected=has_nsfw_concept,
applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
unsafe_images=flagged_images,
)

View File

@@ -0,0 +1,110 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from ...utils import logging
logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class SafeStableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
return images, has_nsfw_concepts
@torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
return images, has_nsfw_concepts

View File

@@ -124,6 +124,21 @@ class StableDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPipelineSafe(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,435 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import tempfile
import unittest
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
from diffusers.utils import floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
@property
def dummy_cond_unet(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
return model
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
def __init__(self):
self.pixel_values = torch.ones([0])
def to(self, device):
self.pixel_values.to(device)
return self
return Out()
return extract
def test_safe_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_no_safety_checker(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
)
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
# check that there's no error when saving a pipeline with one of the models being None
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
# sanity check that the pipeline still works
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
assert image.shape == (1, 128, 128, 3)
@slow
@require_torch_gpu
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_harm_safe_stable_diffusion(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = (
"portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
" coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
" anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
" children from bahnhof zoo, detailed "
)
seed = 4003660346
guidance_scale = 7
# without safety guidance (sld_guidance_scale = 0)
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=0,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# without safety guidance (strong configuration)
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=2000,
sld_warmup_steps=7,
sld_threshold=0.025,
sld_momentum_scale=0.5,
sld_mom_beta=0.7,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_nudity_safe_stable_diffusion(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
seed = 2734971755
guidance_scale = 7
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=0,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=2000,
sld_warmup_steps=7,
sld_threshold=0.025,
sld_momentum_scale=0.5,
sld_mom_beta=0.7,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_nudity_safetychecker_safe_stable_diffusion(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = (
"the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
" leyendecker"
)
seed = 1044355234
guidance_scale = 12
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=0,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
generator = torch.Generator(device=torch_device).manual_seed(seed)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=guidance_scale,
num_inference_steps=50,
output_type="np",
width=512,
height=512,
sld_guidance_scale=2000,
sld_warmup_steps=7,
sld_threshold=0.025,
sld_momentum_scale=0.5,
sld_mom_beta=0.7,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2