1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Add documentation for UniDiffuser and fix some typos/formatting in docstrings.

This commit is contained in:
Daniel Gu
2023-05-10 20:13:28 -07:00
parent abd6fca81e
commit ae7d549e0b
3 changed files with 111 additions and 10 deletions

View File

@@ -222,6 +222,8 @@
title: UnCLIP
- local: api/pipelines/latent_diffusion_uncond
title: Unconditional Latent Diffusion
- local: api/pipelines/unidiffuser
title: UniDiffuser
- local: api/pipelines/versatile_diffusion
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion

View File

@@ -0,0 +1,99 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# UniDiffuser
The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://arxiv.org/abs/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu.
The abstract of the paper is the following:
*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).*
Resources:
* [Paper](https://arxiv.org/abs/2303.06555).
* [Original Code](https://github.com/thu-ml/unidiffuser).
Available Checkpoints are:
- *UniDiffuser-v0 (512x512 resolution)* [dg845/unidiffuser-diffusers-v0](https://huggingface.co/dg845/unidiffuser-diffusers-v0)
- *UniDiffuser-v1 (512x512 resolution)* [dg845/unidiffuser-diffusers-v1](https://huggingface.co/dg845/unidiffuser-diffusers)
## Available Pipelines:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | |
## Usage Example
```python
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import UniDiffuserPipeline
device = "cuda"
model_id_or_path = "dg845/unidiffuser-diffusers-test"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Joint image-text generation. The generation task is automatically inferred.
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
image = sample.images[0]
text = sample.text[0]
image.save("unidiffuser_sample_joint_image.png")
print(text)
# The mode can be set manually. The following is equivalent to the above:
pipe.set_joint_mode()
sample2 = pipe(num_inference_steps=20, guidance_scale=8.0)
# Note that if you set the mode manually the pipeline will no longer attempt
# to automatically infer the mode. You can re-enable this with reset_mode().
pipe.reset_mode()
# Text-to-image generation.
prompt = "an elephant under the sea"
sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
t2i_image = sample.images[0]
t2i_image.save("unidiffuser_sample_text2img_image.png")
# Image-to-text generation.
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
response = requests.get(image_url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(text)
# Image variation can be performed with a image-to-text generation followed by a text-to-image generation:
sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0)
final_image = sample.images[0]
final_image.save("unidiffuser_image_variation_sample.png")
# Text variation can be performed with a text-to-image generation followed by a image-to-text generation:
sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0)
final_prompt = sample.text[0]
print(final_prompt)
```
## UniDiffuserPipeline
[[autodoc]] UniDiffuserPipeline
- all
- __call__
## ImageTextPipelineOutput
[[autodoc]] ImageTextPipelineOutput

View File

@@ -97,18 +97,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
images as part of its image representation, along with the VAE latent representation.
image_processor ([`CLIPImageProcessor`]):
CLIP image processor of class
[`CLIPImageProcessor`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
[CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
used to preprocess the image before CLIP encoding it with `image_encoder`.
clip_tokenizer ([`CLIPTokenizer`]):
Tokenizer of class
[`CLIPTokenizer`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
[CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
is used to tokenizer a prompt before encoding it with `text_encoder`.
text_decoder ([`UniDiffuserTextDecoder`]):
Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
embedding.
text_tokenizer ([`GPT2Tokenizer`]):
Tokenizer of class
[`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
is used along with the `text_decoder` to decode text for text generation.
unet ([`UniDiffuserModel`]):
UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
@@ -173,7 +173,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
@@ -814,7 +814,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
text = einops.rearrange(text, "B L D -> B (L D)")
return torch.concat([img_vae, img_clip, text], dim=-1)
def get_noise_pred(
def _get_noise_pred(
self,
mode,
latents,
@@ -1096,18 +1096,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
guidance_scale (`float`, *optional*, defaults to 8.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. Note that the original [UniDiffuser
paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of guidance scale `w'`, which
paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, which
satisfies `w = w' + 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). Used in text-conditioned image generation (`text2img` mode).
less than `1`). Used in text-conditioned image generation (`text2img`) mode.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
`img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
@@ -1159,7 +1159,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples: Returns:
Returns:
[`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
[`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images, and the second element is a list
@@ -1332,7 +1332,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
for i, t in enumerate(timesteps):
# predict the noise residual
# Also applies classifier-free guidance as described in the UniDiffuser paper
noise_pred = self.get_noise_pred(
noise_pred = self._get_noise_pred(
mode,
latents,
t,