1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add Shap-E (#3742)

* refactor prior_transformer

adding conversion script

add pipeline

add step_index from pipeline, + remove permute

add zero pad token

remove copy from statement for betas_for_alpha_bar function

* add

* add

* update conversion script for renderer model

* refactor camera a little bit

* clean up

* style

* fix copies

* Update src/diffusers/schedulers/scheduling_heun_discrete.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* alpha_transform_type

* remove step_index argument

* remove get_sigmas_karras

* remove _yiyi_sigma_to_t

* move the rescale prompt_embeds from prior_transformer to pipeline

* replace baddbmm with einsum to match origial repo

* Revert "replace baddbmm with einsum to match origial repo"

This reverts commit 3f6b435d65.

* add step_index to scale_model_input

* Revert "move the rescale prompt_embeds from prior_transformer to pipeline"

This reverts commit 5b5a8e6be9.

* move rescale from prior_transformer to pipeline

* correct step_index in scale_model_input

* remove print lines

* refactor prior - reduce arguments

* make style

* add prior_image

* arg embedding_proj_norm -> norm_embedding_proj

* add pre-norm for proj_embedding

* move rescale prompt from pipeline to _encode_prompt

* add img2img pipeline

* style

* copies

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

add arg: encoder_hid_proj

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

add new config: norm_in_type

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

add new config: added_emb_type

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

rename out_dim -> clip_embed_dim

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

rename config: out_dim -> clip_embed_dim

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/prior_transformer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* finish refactor prior_tranformer

* make style

* refactor renderer

* fix

* make style

* refactor img2img

* remove params_proj

* add test

* add upcast_softmax to prior_transformer

* enable num_images_per_prompt, add save_gif utility

* add

* add fast test

* make style

* add slow test

* style

* add test for img2img

* refactor

* enable batching

* style

* refactor scheduler

* update test

* style

* attempt to solve batch related tests timeout

* add doc

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* hardcode rendering related config

* update betas_for_alpha_bar on ddpm_scheduler

* fix copies

* fix

* export_to_gif

* style

* second attempt to speed up batching tests

* add doc page to index

* Remove intermediate clipping

* 3rd attempt to speed up batching tests

* Remvoe time index

* simplify scheduler

* Fix more

* Fix more

* fix more

* make style

* fix schedulers

* fix some more tests

* finish

* add one more test

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* apply feedbacks

* style

* fix copies

* add one example

* style

* add example for img2img

* fix doc

* fix more doc strings

* size -> frame_size

* style

* update doc

* style

* fix on doc

* update repo name

* improve the usage example in shap-e img2img

* add usage examples in the shap-e docs.

* consolidate examples.

* minor fix.

* update doc

* Apply suggestions from code review

* Apply suggestions from code review

* remove upcast

* Make sure background is white

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

* Apply suggestions from code review

* Finish

* Apply suggestions from code review

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py

* Make style

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
YiYi Xu
2023-07-06 03:20:42 -10:00
committed by GitHub
parent 746215670a
commit 45f6d52b10
37 changed files with 3534 additions and 116 deletions

View File

@@ -226,6 +226,8 @@
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/spectrogram_diffusion
title: Spectrogram Diffusion
- sections:

View File

@@ -0,0 +1,139 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Shap-E
## Overview
The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463) by Alex Nichol and Heewon Jun from [OpenAI](https://github.com/openai).
The abstract of the paper is the following:
*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*
The original codebase can be found [here](https://github.com/openai/shap-e).
## Available Pipelines:
| Pipeline | Tasks |
|---|---|
| [pipeline_shap_e.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e.py) | *Text-to-Image Generation* |
| [pipeline_shap_e_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py) | *Image-to-Image Generation* |
## Available checkpoints
* [`openai/shap-e`](https://huggingface.co/openai/shap-e)
* [`openai/shap-e-img2img`](https://huggingface.co/openai/shap-e-img2img)
## Usage Examples
In the following, we will walk you through some examples of how to use Shap-E pipelines to create 3D objects in gif format.
### Text-to-3D image generation
We can use [`ShapEPipeline`] to create 3D object based on a text prompt. In this example, we will make a birthday cupcake for :firecracker: diffusers library's 1 year birthday. The workflow to use the Shap-E text-to-image pipeline is same as how you would use other text-to-image pipelines in diffusers.
```python
import torch
from diffusers import DiffusionPipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo = "openai/shap-e"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
pipe = pipe.to(device)
guidance_scale = 15.0
prompt = ["A firecracker", "A birthday cupcake"]
images = pipe(
prompt,
guidance_scale=guidance_scale,
num_inference_steps=64,
frame_size=256,
).images
```
The output of [`ShapEPipeline`] is a list of lists of images frames. Each list of frames can be used to create a 3D object. Let's use the `export_to_gif` utility function in diffusers to make a 3D cupcake!
```python
from diffusers.utils import export_to_gif
export_to_gif(images[0], "firecracker_3d.gif")
export_to_gif(images[1], "cake_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/firecracker_out.gif)
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/cake_out.gif)
### Image-to-Image generation
You can use [`ShapEImg2ImgPipeline`] along with other text-to-image pipelines in diffusers and turn your 2D generation into 3D.
In this example, We will first genrate a cheeseburger with a simple prompt "A cheeseburger, white background"
```python
from diffusers import DiffusionPipeline
import torch
pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
pipe_prior.to("cuda")
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
t2i_pipe.to("cuda")
prompt = "A cheeseburger, white background"
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
image = t2i_pipe(
prompt,
image_embeds=image_embeds,
negative_image_embeds=negative_image_embeds,
).images[0]
image.save("burger.png")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_in.png)
we will then use the Shap-E image-to-image pipeline to turn it into a 3D cheeseburger :)
```python
from PIL import Image
from diffusers.utils import export_to_gif
repo = "openai/shap-e-img2img"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
guidance_scale = 3.0
image = Image.open("burger.png").resize((256, 256))
images = pipe(
image,
guidance_scale=guidance_scale,
num_inference_steps=64,
frame_size=256,
).images
gif_path = export_to_gif(images[0], "burger_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif)
## ShapEPipeline
[[autodoc]] ShapEPipeline
- all
- __call__
## ShapEImg2ImgPipeline
[[autodoc]] ShapEImg2ImgPipeline
- all
- __call__

View File

@@ -0,0 +1,594 @@
import argparse
import tempfile
import torch
from accelerate import load_checkpoint_and_dispatch
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapERenderer
"""
Example - From the diffusers root directory:
Download weights:
```sh
$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
```
Convert the model:
```sh
$ python scripts/convert_shap_e_to_diffusers.py \
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
--prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
--transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
--debug renderer
```
"""
# prior
PRIOR_ORIGINAL_PREFIX = "wrapped"
PRIOR_CONFIG = {
"num_attention_heads": 16,
"attention_head_dim": 1024 // 16,
"num_layers": 24,
"embedding_dim": 1024,
"num_embeddings": 1024,
"additional_embeddings": 0,
"time_embed_act_fn": "gelu",
"norm_in_type": "layer",
"encoder_hid_proj_type": None,
"added_emb_type": None,
"time_embed_dim": 1024 * 4,
"embedding_proj_dim": 768,
"clip_embed_dim": 1024 * 2,
}
def prior_model_from_original_config():
model = PriorTransformer(**PRIOR_CONFIG)
return model
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
diffusers_checkpoint.update(
{
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
}
)
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
diffusers_checkpoint.update(
{
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
}
)
# <original>.input_proj -> <diffusers>.proj_in
diffusers_checkpoint.update(
{
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"],
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"],
}
)
# <original>.clip_emb -> <diffusers>.embedding_proj
diffusers_checkpoint.update(
{
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"],
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"],
}
)
# <original>.pos_emb -> <diffusers>.positional_embedding
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]})
# <original>.ln_pre -> <diffusers>.norm_in
diffusers_checkpoint.update(
{
"norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"],
"norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"],
}
)
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
for idx in range(len(model.transformer_blocks)):
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
# <original>.attn -> <diffusers>.attn1
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
original_attention_prefix = f"{original_transformer_prefix}.attn"
diffusers_checkpoint.update(
prior_attention_to_diffusers(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
original_attention_prefix=original_attention_prefix,
attention_head_dim=model.attention_head_dim,
)
)
# <original>.mlp -> <diffusers>.ff
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
original_ff_prefix = f"{original_transformer_prefix}.mlp"
diffusers_checkpoint.update(
prior_ff_to_diffusers(
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
)
)
# <original>.ln_1 -> <diffusers>.norm1
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
f"{original_transformer_prefix}.ln_1.weight"
],
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
}
)
# <original>.ln_2 -> <diffusers>.norm3
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
f"{original_transformer_prefix}.ln_2.weight"
],
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
}
)
# <original>.ln_post -> <diffusers>.norm_out
diffusers_checkpoint.update(
{
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"],
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"],
}
)
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
diffusers_checkpoint.update(
{
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"],
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"],
}
)
return diffusers_checkpoint
def prior_attention_to_diffusers(
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
):
diffusers_checkpoint = {}
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
split=3,
chunk_size=attention_head_dim,
)
diffusers_checkpoint.update(
{
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
}
)
# <original>.c_proj -> <diffusers>.to_out.0
diffusers_checkpoint.update(
{
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
}
)
return diffusers_checkpoint
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
diffusers_checkpoint = {
# <original>.c_fc -> <diffusers>.net.0.proj
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
# <original>.c_proj -> <diffusers>.net.2
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
}
return diffusers_checkpoint
# done prior
# prior_image (only slightly different from prior)
PRIOR_IMAGE_ORIGINAL_PREFIX = "wrapped"
# Uses default arguments
PRIOR_IMAGE_CONFIG = {
"num_attention_heads": 8,
"attention_head_dim": 1024 // 8,
"num_layers": 24,
"embedding_dim": 1024,
"num_embeddings": 1024,
"additional_embeddings": 0,
"time_embed_act_fn": "gelu",
"norm_in_type": "layer",
"embedding_proj_norm_type": "layer",
"encoder_hid_proj_type": None,
"added_emb_type": None,
"time_embed_dim": 1024 * 4,
"embedding_proj_dim": 1024,
"clip_embed_dim": 1024 * 2,
}
def prior_image_model_from_original_config():
model = PriorTransformer(**PRIOR_IMAGE_CONFIG)
return model
def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
diffusers_checkpoint.update(
{
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
}
)
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
diffusers_checkpoint.update(
{
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
}
)
# <original>.input_proj -> <diffusers>.proj_in
diffusers_checkpoint.update(
{
"proj_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.weight"],
"proj_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.bias"],
}
)
# <original>.clip_embed.0 -> <diffusers>.embedding_proj_norm
diffusers_checkpoint.update(
{
"embedding_proj_norm.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.weight"],
"embedding_proj_norm.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.bias"],
}
)
# <original>..clip_embed.1 -> <diffusers>.embedding_proj
diffusers_checkpoint.update(
{
"embedding_proj.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.weight"],
"embedding_proj.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.bias"],
}
)
# <original>.pos_emb -> <diffusers>.positional_embedding
diffusers_checkpoint.update(
{"positional_embedding": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.pos_emb"][None, :]}
)
# <original>.ln_pre -> <diffusers>.norm_in
diffusers_checkpoint.update(
{
"norm_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.weight"],
"norm_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.bias"],
}
)
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
for idx in range(len(model.transformer_blocks)):
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
original_transformer_prefix = f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
# <original>.attn -> <diffusers>.attn1
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
original_attention_prefix = f"{original_transformer_prefix}.attn"
diffusers_checkpoint.update(
prior_attention_to_diffusers(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
original_attention_prefix=original_attention_prefix,
attention_head_dim=model.attention_head_dim,
)
)
# <original>.mlp -> <diffusers>.ff
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
original_ff_prefix = f"{original_transformer_prefix}.mlp"
diffusers_checkpoint.update(
prior_ff_to_diffusers(
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
)
)
# <original>.ln_1 -> <diffusers>.norm1
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
f"{original_transformer_prefix}.ln_1.weight"
],
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
}
)
# <original>.ln_2 -> <diffusers>.norm3
diffusers_checkpoint.update(
{
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
f"{original_transformer_prefix}.ln_2.weight"
],
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
}
)
# <original>.ln_post -> <diffusers>.norm_out
diffusers_checkpoint.update(
{
"norm_out.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.weight"],
"norm_out.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.bias"],
}
)
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
diffusers_checkpoint.update(
{
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.weight"],
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.bias"],
}
)
return diffusers_checkpoint
# done prior_image
# renderer
RENDERER_CONFIG = {}
def renderer_model_from_original_config():
model = ShapERenderer(**RENDERER_CONFIG)
return model
RENDERER_MLP_ORIGINAL_PREFIX = "renderer.nerstf"
RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"
def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
diffusers_checkpoint.update(
{f"mlp.{k}": checkpoint[f"{RENDERER_MLP_ORIGINAL_PREFIX}.{k}"] for k in model.mlp.state_dict().keys()}
)
diffusers_checkpoint.update(
{
f"params_proj.{k}": checkpoint[f"{RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"]
for k in model.params_proj.state_dict().keys()
}
)
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
return diffusers_checkpoint
# done renderer
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
weights = [None] * split
biases = [None] * split
weights_biases_idx = 0
for starting_row_index in range(0, weight.shape[0], chunk_size):
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
weight_rows = weight[row_indices, :]
bias_rows = bias[row_indices]
if weights[weights_biases_idx] is None:
assert weights[weights_biases_idx] is None
weights[weights_biases_idx] = weight_rows
biases[weights_biases_idx] = bias_rows
else:
assert weights[weights_biases_idx] is not None
weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
weights_biases_idx = (weights_biases_idx + 1) % split
return weights, biases
# done unet utils
# Driver functions
def prior(*, args, checkpoint_map_location):
print("loading prior")
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
prior_model = prior_model_from_original_config()
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
del prior_checkpoint
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
print("done loading prior")
return prior_model
def prior_image(*, args, checkpoint_map_location):
print("loading prior_image")
print(f"load checkpoint from {args.prior_image_checkpoint_path}")
prior_checkpoint = torch.load(args.prior_image_checkpoint_path, map_location=checkpoint_map_location)
prior_model = prior_image_model_from_original_config()
prior_diffusers_checkpoint = prior_image_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
del prior_checkpoint
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
print("done loading prior_image")
return prior_model
def renderer(*, args, checkpoint_map_location):
print(" loading renderer")
renderer_checkpoint = torch.load(args.transmitter_checkpoint_path, map_location=checkpoint_map_location)
renderer_model = renderer_model_from_original_config()
renderer_diffusers_checkpoint = renderer_model_original_checkpoint_to_diffusers_checkpoint(
renderer_model, renderer_checkpoint
)
del renderer_checkpoint
load_checkpoint_to_model(renderer_diffusers_checkpoint, renderer_model, strict=True)
print("done loading renderer")
return renderer_model
# prior model will expect clip_mean and clip_std, whic are missing from the state_dict
PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]
def load_prior_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile() as file:
torch.save(checkpoint, file.name)
del checkpoint
missing_keys, unexpected_keys = model.load_state_dict(torch.load(file.name), strict=False)
missing_keys = list(set(missing_keys) - set(PRIOR_EXPECTED_MISSING_KEYS))
if len(unexpected_keys) > 0:
raise ValueError(f"Unexpected keys when loading prior model: {unexpected_keys}")
if len(missing_keys) > 0:
raise ValueError(f"Missing keys when loading prior model: {missing_keys}")
def load_checkpoint_to_model(checkpoint, model, strict=False):
with tempfile.NamedTemporaryFile() as file:
torch.save(checkpoint, file.name)
del checkpoint
if strict:
model.load_state_dict(torch.load(file.name), strict=True)
else:
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--prior_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to the prior checkpoint to convert.",
)
parser.add_argument(
"--prior_image_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to the prior_image checkpoint to convert.",
)
parser.add_argument(
"--transmitter_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to the transmitter checkpoint to convert.",
)
parser.add_argument(
"--checkpoint_load_device",
default="cpu",
type=str,
required=False,
help="The device passed to `map_location` when loading checkpoints.",
)
parser.add_argument(
"--debug",
default=None,
type=str,
required=False,
help="Only run a specific stage of the convert script. Used for debugging",
)
args = parser.parse_args()
print(f"loading checkpoints to {args.checkpoint_load_device}")
checkpoint_map_location = torch.device(args.checkpoint_load_device)
if args.debug is not None:
print(f"debug: only executing {args.debug}")
if args.debug is None:
print("YiYi TO-DO")
elif args.debug == "prior":
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
elif args.debug == "prior_image":
prior_model = prior_image(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
elif args.debug == "renderer":
renderer_model = renderer(args=args, checkpoint_map_location=checkpoint_map_location)
renderer_model.save_pretrained(args.dump_path)
else:
raise ValueError(f"unknown debug value : {args.debug}")

View File

@@ -149,6 +149,8 @@ else:
LDMTextToImagePipeline,
PaintByExamplePipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,

View File

@@ -34,14 +34,33 @@ class PriorTransformer(ModelMixin, ConfigMixin):
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768):
The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the
length of the prompt after it has been tokenized).
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
num_embeddings (`int`, *optional*, defaults to 77):
The number of embeddings of the model input `hidden_states`
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
The activation function to use to create timestep embeddings.
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
passing to Transformer blocks. Set it to `None` if normalization is not needed.
embedding_proj_norm_type (`str`, *optional*, defaults to None):
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
needed.
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
`encoder_hidden_states` is `None`.
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
product between the text embedding and image embedding as proposed in the unclip paper
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
If None, will be set to `num_attention_heads * attention_head_dim`
embedding_proj_dim (`int`, *optional*, default to None):
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
clip_embed_dim (`int`, *optional*, default to None):
The dimension of the output. If None, will be set to `embedding_dim`.
"""
@register_to_config
@@ -54,6 +73,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
time_embed_act_fn: str = "silu",
norm_in_type: Optional[str] = None, # layer
embedding_proj_norm_type: Optional[str] = None, # layer
encoder_hid_proj_type: Optional[str] = "linear", # linear
added_emb_type: Optional[str] = "prd", # prd
time_embed_dim: Optional[int] = None,
embedding_proj_dim: Optional[int] = None,
clip_embed_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -61,17 +88,41 @@ class PriorTransformer(ModelMixin, ConfigMixin):
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings
time_embed_dim = time_embed_dim or inner_dim
embedding_proj_dim = embedding_proj_dim or embedding_dim
clip_embed_dim = clip_embed_dim or embedding_dim
self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
self.proj_in = nn.Linear(embedding_dim, inner_dim)
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
if embedding_proj_norm_type is None:
self.embedding_proj_norm = None
elif embedding_proj_norm_type == "layer":
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
else:
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
if encoder_hid_proj_type is None:
self.encoder_hidden_states_proj = None
elif encoder_hid_proj_type == "linear":
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
else:
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
if added_emb_type == "prd":
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
elif added_emb_type is None:
self.prd_embedding = None
else:
raise ValueError(
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
)
self.transformer_blocks = nn.ModuleList(
[
@@ -87,8 +138,16 @@ class PriorTransformer(ModelMixin, ConfigMixin):
]
)
if norm_in_type == "layer":
self.norm_in = nn.LayerNorm(inner_dim)
elif norm_in_type is None:
self.norm_in = None
else:
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
@@ -97,8 +156,8 @@ class PriorTransformer(ModelMixin, ConfigMixin):
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
@@ -172,7 +231,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
@@ -217,23 +276,61 @@ class PriorTransformer(ModelMixin, ConfigMixin):
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
if self.embedding_proj_norm is not None:
proj_embedding = self.embedding_proj_norm(proj_embedding)
proj_embeddings = self.embedding_proj(proj_embedding)
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
hidden_states = self.proj_in(hidden_states)
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
additional_embeds = []
additional_embeddings_len = 0
if encoder_hidden_states is not None:
additional_embeds.append(encoder_hidden_states)
additional_embeddings_len += encoder_hidden_states.shape[1]
if len(proj_embeddings.shape) == 2:
proj_embeddings = proj_embeddings[:, None, :]
if len(hidden_states.shape) == 2:
hidden_states = hidden_states[:, None, :]
additional_embeds = additional_embeds + [
proj_embeddings,
time_embeddings[:, None, :],
hidden_states,
]
if self.prd_embedding is not None:
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
additional_embeds.append(prd_embedding)
hidden_states = torch.cat(
[
encoder_hidden_states,
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :],
prd_embedding,
],
additional_embeds,
dim=1,
)
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
if positional_embeddings.shape[1] < hidden_states.shape[1]:
positional_embeddings = F.pad(
positional_embeddings,
(
0,
0,
additional_embeddings_len,
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
),
value=0.0,
)
hidden_states = hidden_states + positional_embeddings
if attention_mask is not None:
@@ -242,11 +339,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -1]
if self.prd_embedding is not None:
hidden_states = hidden_states[:, -1]
else:
hidden_states = hidden_states[:, additional_embeddings_len:]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
if not return_dict:

View File

@@ -77,6 +77,7 @@ else:
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,

View File

@@ -0,0 +1,27 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .camera import create_pan_cameras
from .pipeline_shap_e import ShapEPipeline
from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline
from .renderer import (
BoundingBoxVolume,
ImportanceRaySampler,
MLPNeRFModelOutput,
MLPNeRSTFModel,
ShapEParamsProjModel,
ShapERenderer,
StratifiedRaySampler,
VoidNeRFModel,
)

View File

@@ -0,0 +1,147 @@
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import torch
@dataclass
class DifferentiableProjectiveCamera:
"""
Implements a batch, differentiable, standard pinhole camera
"""
origin: torch.Tensor # [batch_size x 3]
x: torch.Tensor # [batch_size x 3]
y: torch.Tensor # [batch_size x 3]
z: torch.Tensor # [batch_size x 3]
width: int
height: int
x_fov: float
y_fov: float
shape: Tuple[int]
def __post_init__(self):
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
def resolution(self):
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
def fov(self):
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
def get_image_coords(self) -> torch.Tensor:
"""
:return: coords of shape (width * height, 2)
"""
pixel_indices = torch.arange(self.height * self.width)
coords = torch.stack(
[
pixel_indices % self.width,
torch.div(pixel_indices, self.width, rounding_mode="trunc"),
],
axis=1,
)
return coords
@property
def camera_rays(self):
batch_size, *inner_shape = self.shape
inner_batch_size = int(np.prod(inner_shape))
coords = self.get_image_coords()
coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
rays = self.get_camera_rays(coords)
rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
return rays
def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
batch_size, *shape, n_coords = coords.shape
assert n_coords == 2
assert batch_size == self.origin.shape[0]
flat = coords.view(batch_size, -1, 2)
res = self.resolution()
fov = self.fov()
fracs = (flat.float() / (res - 1)) * 2 - 1
fracs = fracs * torch.tan(fov / 2)
fracs = fracs.view(batch_size, -1, 2)
directions = (
self.z.view(batch_size, 1, 3)
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
)
directions = directions / directions.norm(dim=-1, keepdim=True)
rays = torch.stack(
[
torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
directions,
],
dim=2,
)
return rays.view(batch_size, *shape, 2, 3)
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
"""
Creates a new camera for the resized view assuming the aspect ratio does not change.
"""
assert width * self.height == height * self.width, "The aspect ratio should not change."
return DifferentiableProjectiveCamera(
origin=self.origin,
x=self.x,
y=self.y,
z=self.z,
width=width,
height=height,
x_fov=self.x_fov,
y_fov=self.y_fov,
)
def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
origins = []
xs = []
ys = []
zs = []
for theta in np.linspace(0, 2 * np.pi, num=20):
z = np.array([np.sin(theta), np.cos(theta), -0.5])
z /= np.sqrt(np.sum(z**2))
origin = -z * 4
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
y = np.cross(z, x)
origins.append(origin)
xs.append(x)
ys.append(y)
zs.append(z)
return DifferentiableProjectiveCamera(
origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
x=torch.from_numpy(np.stack(xs, axis=0)).float(),
y=torch.from_numpy(np.stack(ys, axis=0)).float(),
z=torch.from_numpy(np.stack(zs, axis=0)).float(),
width=size,
height=size,
x_fov=0.7,
y_fov=0.7,
shape=(1, len(xs)),
)

View File

@@ -0,0 +1,390 @@
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from .renderer import ShapERenderer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from diffusers.utils import export_to_gif
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> repo = "openai/shap-e"
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> guidance_scale = 15.0
>>> prompt = "a shark"
>>> images = pipe(
... prompt,
... guidance_scale=guidance_scale,
... num_inference_steps=64,
... frame_size=256,
... ).images
>>> gif_path = export_to_gif(images[0], "shark_3d.gif")
```
"""
@dataclass
class ShapEPipelineOutput(BaseOutput):
"""
Output class for ShapEPipeline.
Args:
images (`torch.FloatTensor`)
a list of images for 3D rendering
"""
images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
class ShapEPipeline(DiffusionPipeline):
"""
Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
prior ([`PriorTransformer`]):
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
text_encoder ([`CLIPTextModelWithProjection`]):
Frozen text-encoder.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`HeunDiscreteScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
renderer ([`ShapERenderer`]):
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
with the NeRF rendering method
"""
def __init__(
self,
prior: PriorTransformer,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
scheduler: HeunDiscreteScheduler,
renderer: ShapERenderer,
):
super().__init__()
self.register_modules(
prior=prior,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
renderer=renderer,
)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [self.text_encoder, self.prior]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
if self.safety_checker is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
return self.device
for module in self.text_encoder.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
):
len(prompt) if isinstance(prompt, list) else 1
# YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
self.tokenizer.pad_token_id = 0
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
# in Shap-E it normalize the prompt_embeds and then later rescale it
prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
if do_classifier_free_guidance:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# Rescale the features to have unit variance
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
return prompt_embeds
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str,
num_images_per_prompt: int = 1,
num_inference_steps: int = 25,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 4.0,
frame_size: int = 64,
output_type: Optional[str] = "pil", # pil, np, latent
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
frame_size (`int`, *optional*, default to 64):
the width and height of each image frame of the generated 3d output
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`ShapEPipelineOutput`] or `tuple`
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
# prior
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
num_embeddings = self.prior.config.num_embeddings
embedding_dim = self.prior.config.embedding_dim
latents = self.prepare_latents(
(batch_size, num_embeddings * embedding_dim),
prompt_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.prior(
scaled_model_input,
timestep=t,
proj_embedding=prompt_embeds,
).predicted_image_embedding
# remove the variance
noise_pred, _ = noise_pred.split(
scaled_model_input.shape[2], dim=2
) # batch_size, num_embeddings, embedding_dim
if do_classifier_free_guidance is not None:
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
timestep=t,
sample=latents,
).prev_sample
if output_type == "latent":
return ShapEPipelineOutput(images=latents)
images = []
for i, latent in enumerate(latents):
image = self.renderer.decode(
latent[None, :],
device,
size=frame_size,
ray_batch_size=4096,
n_coarse_samples=64,
n_fine_samples=128,
)
images.append(image)
images = torch.stack(images)
if output_type not in ["np", "pil"]:
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
images = images.cpu().numpy()
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (images,)
return ShapEPipelineOutput(images=images)

View File

@@ -0,0 +1,349 @@
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPVisionModel
from ...models import PriorTransformer
from ...pipelines import DiffusionPipeline
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
is_accelerate_available,
logging,
randn_tensor,
replace_example_docstring,
)
from .renderer import ShapERenderer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from PIL import Image
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from diffusers.utils import export_to_gif, load_image
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> repo = "openai/shap-e-img2img"
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> guidance_scale = 3.0
>>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png"
>>> image = load_image(image_url).convert("RGB")
>>> images = pipe(
... image,
... guidance_scale=guidance_scale,
... num_inference_steps=64,
... frame_size=256,
... ).images
>>> gif_path = export_to_gif(images[0], "corgi_3d.gif")
```
"""
@dataclass
class ShapEPipelineOutput(BaseOutput):
"""
Output class for ShapEPipeline.
Args:
images (`torch.FloatTensor`)
a list of images for 3D rendering
"""
images: Union[PIL.Image.Image, np.ndarray]
class ShapEImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
prior ([`PriorTransformer`]):
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
text_encoder ([`CLIPTextModelWithProjection`]):
Frozen text-encoder.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`HeunDiscreteScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
renderer ([`ShapERenderer`]):
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
with the NeRF rendering method
"""
def __init__(
self,
prior: PriorTransformer,
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
scheduler: HeunDiscreteScheduler,
renderer: ShapERenderer,
):
super().__init__()
self.register_modules(
prior=prior,
image_encoder=image_encoder,
image_processor=image_processor,
scheduler=scheduler,
renderer=renderer,
)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [self.image_encoder, self.prior]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"):
return self.device
for module in self.image_encoder.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_image(
self,
image,
device,
num_images_per_prompt,
do_classifier_free_guidance,
):
if isinstance(image, List) and isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
if not isinstance(image, torch.Tensor):
image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)
image = image.to(dtype=self.image_encoder.dtype, device=device)
image_embeds = self.image_encoder(image)["last_hidden_state"]
image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
negative_image_embeds = torch.zeros_like(image_embeds)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeds = torch.cat([negative_image_embeds, image_embeds])
return image_embeds
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image]],
num_images_per_prompt: int = 1,
num_inference_steps: int = 25,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 4.0,
frame_size: int = 64,
output_type: Optional[str] = "pil", # pil, np, latent
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
frame_size (`int`, *optional*, default to 64):
the width and height of each image frame of the generated 3d output
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`ShapEPipelineOutput`] or `tuple`
"""
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, torch.Tensor):
batch_size = image.shape[0]
elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
batch_size = len(image)
else:
raise ValueError(
f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
)
device = self._execution_device
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = guidance_scale > 1.0
image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
# prior
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
num_embeddings = self.prior.config.num_embeddings
embedding_dim = self.prior.config.embedding_dim
latents = self.prepare_latents(
(batch_size, num_embeddings * embedding_dim),
image_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.prior(
scaled_model_input,
timestep=t,
proj_embedding=image_embeds,
).predicted_image_embedding
# remove the variance
noise_pred, _ = noise_pred.split(
scaled_model_input.shape[2], dim=2
) # batch_size, num_embeddings, embedding_dim
if do_classifier_free_guidance is not None:
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
timestep=t,
sample=latents,
).prev_sample
if output_type == "latent":
return ShapEPipelineOutput(images=latents)
images = []
for i, latent in enumerate(latents):
print()
image = self.renderer.decode(
latent[None, :],
device,
size=frame_size,
ray_batch_size=4096,
n_coarse_samples=64,
n_fine_samples=128,
)
images.append(image)
images = torch.stack(images)
if output_type not in ["np", "pil"]:
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
images = images.cpu().numpy()
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (images,)
return ShapEPipelineOutput(images=images)

View File

@@ -0,0 +1,709 @@
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...utils import BaseOutput
from .camera import create_pan_cameras
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
r"""
Sample from the given discrete probability distribution with replacement.
The i-th bin is assumed to have mass pmf[i].
Args:
pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all()
n_samples: number of samples
Return:
indices sampled with replacement
"""
*shape, support_size, last_dim = pmf.shape
assert last_dim == 1
cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
"""
Concatenate x and its positional encodings, following NeRF.
Reference: https://arxiv.org/pdf/2210.04628.pdf
"""
if min_deg == max_deg:
return x
scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
*shape, dim = x.shape
xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
assert xb.shape[-1] == dim * (max_deg - min_deg)
emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
return torch.cat([x, emb], dim=-1)
def encode_position(position):
return posenc_nerf(position, min_deg=0, max_deg=15)
def encode_direction(position, direction=None):
if direction is None:
return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
else:
return posenc_nerf(direction, min_deg=0, max_deg=8)
def _sanitize_name(x: str) -> str:
return x.replace(".", "__")
def integrate_samples(volume_range, ts, density, channels):
r"""
Function integrating the model output.
Args:
volume_range: Specifies the integral range [t0, t1]
ts: timesteps
density: torch.Tensor [batch_size, *shape, n_samples, 1]
channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
returns:
channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density
*transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume
)
"""
# 1. Calculate the weights
_, _, dt = volume_range.partition(ts)
ddensity = density * dt
mass = torch.cumsum(ddensity, dim=-2)
transmittance = torch.exp(-mass[..., -1, :])
alphas = 1.0 - torch.exp(-ddensity)
Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
# This is the probability of light hitting and reflecting off of
# something at depth [..., i, :].
weights = alphas * Ts
# 2. Integrate channels
channels = torch.sum(channels * weights, dim=-2)
return channels, weights, transmittance
class VoidNeRFModel(nn.Module):
"""
Implements the default empty space model where all queries are rendered as background.
"""
def __init__(self, background, channel_scale=255.0):
super().__init__()
background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
self.register_buffer("background", background)
def forward(self, position):
background = self.background[None].to(position.device)
shape = position.shape[:-1]
ones = [1] * (len(shape) - 1)
n_channels = background.shape[-1]
background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
return background
@dataclass
class VolumeRange:
t0: torch.Tensor
t1: torch.Tensor
intersected: torch.Tensor
def __post_init__(self):
assert self.t0.shape == self.t1.shape == self.intersected.shape
def partition(self, ts):
"""
Partitions t0 and t1 into n_samples intervals.
Args:
ts: [batch_size, *shape, n_samples, 1]
Return:
lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
*shape, n_samples, 1]
where
ts \\in [lower, upper] deltas = upper - lower
"""
mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
delta = upper - lower
assert lower.shape == upper.shape == delta.shape == ts.shape
return lower, upper, delta
class BoundingBoxVolume(nn.Module):
"""
Axis-aligned bounding box defined by the two opposite corners.
"""
def __init__(
self,
*,
bbox_min,
bbox_max,
min_dist: float = 0.0,
min_t_range: float = 1e-3,
):
"""
Args:
bbox_min: the left/bottommost corner of the bounding box
bbox_max: the other corner of the bounding box
min_dist: all rays should start at least this distance away from the origin.
"""
super().__init__()
self.min_dist = min_dist
self.min_t_range = min_t_range
self.bbox_min = torch.tensor(bbox_min)
self.bbox_max = torch.tensor(bbox_max)
self.bbox = torch.stack([self.bbox_min, self.bbox_max])
assert self.bbox.shape == (2, 3)
assert min_dist >= 0.0
assert min_t_range > 0.0
def intersect(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0_lower: Optional[torch.Tensor] = None,
epsilon=1e-6,
):
"""
Args:
origin: [batch_size, *shape, 3]
direction: [batch_size, *shape, 3]
t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
params: Optional meta parameters in case Volume is parametric
epsilon: to stabilize calculations
Return:
A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
be on the boundary of the volume.
"""
batch_size, *shape, _ = origin.shape
ones = [1] * len(shape)
bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
def _safe_divide(a, b, epsilon=1e-6):
return a / torch.where(b < 0, b - epsilon, b + epsilon)
ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
# Cases to think about:
#
# 1. t1 <= t0: the ray does not pass through the AABB.
# 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
# 3. t0 <= 0 <= t1: the ray starts from inside the BB
# 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
#
# 1 and 4 are clearly handled from t0 < t1 below.
# Making t0 at least min_dist (>= 0) takes care of 2 and 3.
t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
assert t0.shape == t1.shape == (batch_size, *shape, 1)
if t0_lower is not None:
assert t0.shape == t0_lower.shape
t0 = torch.maximum(t0, t0_lower)
intersected = t0 + self.min_t_range < t1
t0 = torch.where(intersected, t0, torch.zeros_like(t0))
t1 = torch.where(intersected, t1, torch.ones_like(t1))
return VolumeRange(t0=t0, t1=t1, intersected=intersected)
class StratifiedRaySampler(nn.Module):
"""
Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
"""
def __init__(self, depth_mode: str = "linear"):
"""
:param depth_mode: linear samples ts linearly in depth. harmonic ensures
closer points are sampled more densely.
"""
self.depth_mode = depth_mode
assert self.depth_mode in ("linear", "geometric", "harmonic")
def sample(
self,
t0: torch.Tensor,
t1: torch.Tensor,
n_samples: int,
epsilon: float = 1e-3,
) -> torch.Tensor:
"""
Args:
t0: start time has shape [batch_size, *shape, 1]
t1: finish time has shape [batch_size, *shape, 1]
n_samples: number of ts to sample
Return:
sampled ts of shape [batch_size, *shape, n_samples, 1]
"""
ones = [1] * (len(t0.shape) - 1)
ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
if self.depth_mode == "linear":
ts = t0 * (1.0 - ts) + t1 * ts
elif self.depth_mode == "geometric":
ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
elif self.depth_mode == "harmonic":
# The original NeRF recommends this interpolation scheme for
# spherical scenes, but there could be some weird edge cases when
# the observer crosses from the inner to outer volume.
ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
upper = torch.cat([mids, t1], dim=-1)
lower = torch.cat([t0, mids], dim=-1)
# yiyi notes: add a random seed here for testing, don't forget to remove
torch.manual_seed(0)
t_rand = torch.rand_like(ts)
ts = lower + (upper - lower) * t_rand
return ts.unsqueeze(-1)
class ImportanceRaySampler(nn.Module):
"""
Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
"""
def __init__(
self,
volume_range: VolumeRange,
ts: torch.Tensor,
weights: torch.Tensor,
blur_pool: bool = False,
alpha: float = 1e-5,
):
"""
Args:
volume_range: the range in which a ray intersects the given volume.
ts: earlier samples from the coarse rendering step
weights: discretized version of density * transmittance
blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
alpha: small value to add to weights.
"""
self.volume_range = volume_range
self.ts = ts.clone().detach()
self.weights = weights.clone().detach()
self.blur_pool = blur_pool
self.alpha = alpha
@torch.no_grad()
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
"""
Args:
t0: start time has shape [batch_size, *shape, 1]
t1: finish time has shape [batch_size, *shape, 1]
n_samples: number of ts to sample
Return:
sampled ts of shape [batch_size, *shape, n_samples, 1]
"""
lower, upper, _ = self.volume_range.partition(self.ts)
batch_size, *shape, n_coarse_samples, _ = self.ts.shape
weights = self.weights
if self.blur_pool:
padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
weights = weights + self.alpha
pmf = weights / weights.sum(dim=-2, keepdim=True)
inds = sample_pmf(pmf, n_samples)
assert inds.shape == (batch_size, *shape, n_samples, 1)
assert (inds >= 0).all() and (inds < n_coarse_samples).all()
t_rand = torch.rand(inds.shape, device=inds.device)
lower_ = torch.gather(lower, -2, inds)
upper_ = torch.gather(upper, -2, inds)
ts = lower_ + (upper_ - lower_) * t_rand
ts = torch.sort(ts, dim=-2).values
return ts
@dataclass
class MLPNeRFModelOutput(BaseOutput):
density: torch.Tensor
signed_distance: torch.Tensor
channels: torch.Tensor
ts: torch.Tensor
class MLPNeRSTFModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
d_hidden: int = 256,
n_output: int = 12,
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
):
super().__init__()
# Instantiate the MLP
# Find out the dimension of encoded position and direction
dummy = torch.eye(1, 3)
d_posenc_pos = encode_position(position=dummy).shape[-1]
d_posenc_dir = encode_direction(position=dummy).shape[-1]
mlp_widths = [d_hidden] * n_hidden_layers
input_widths = [d_posenc_pos] + mlp_widths
output_widths = mlp_widths + [n_output]
if insert_direction_at is not None:
input_widths[insert_direction_at] += d_posenc_dir
self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])
if act_fn == "swish":
# self.activation = swish
# yiyi testing:
self.activation = lambda x: F.silu(x)
else:
raise ValueError(f"Unsupported activation function {act_fn}")
self.sdf_activation = torch.tanh
self.density_activation = torch.nn.functional.relu
self.channel_activation = torch.sigmoid
def map_indices_to_keys(self, output):
h_map = {
"sdf": (0, 1),
"density_coarse": (1, 2),
"density_fine": (2, 3),
"stf": (3, 6),
"nerf_coarse": (6, 9),
"nerf_fine": (9, 12),
}
mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}
return mapped_output
def forward(self, *, position, direction, ts, nerf_level="coarse"):
h = encode_position(position)
h_preact = h
h_directionless = None
for i, layer in enumerate(self.mlp):
if i == self.config.insert_direction_at: # 4 in the config
h_directionless = h_preact
h_direction = encode_direction(position, direction=direction)
h = torch.cat([h, h_direction], dim=-1)
h = layer(h)
h_preact = h
if i < len(self.mlp) - 1:
h = self.activation(h)
h_final = h
if h_directionless is None:
h_directionless = h_preact
activation = self.map_indices_to_keys(h_final)
if nerf_level == "coarse":
h_density = activation["density_coarse"]
h_channels = activation["nerf_coarse"]
else:
h_density = activation["density_fine"]
h_channels = activation["nerf_fine"]
density = self.density_activation(h_density)
signed_distance = self.sdf_activation(activation["sdf"])
channels = self.channel_activation(h_channels)
# yiyi notes: I think signed_distance is not used
return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
class ChannelsProj(nn.Module):
def __init__(
self,
*,
vectors: int,
channels: int,
d_latent: int,
):
super().__init__()
self.proj = nn.Linear(d_latent, vectors * channels)
self.norm = nn.LayerNorm(channels)
self.d_latent = d_latent
self.vectors = vectors
self.channels = channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_bvd = x
w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
b_vc = self.proj.bias.view(1, self.vectors, self.channels)
h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
h = self.norm(h)
h = h + b_vc
return h
class ShapEParamsProjModel(ModelMixin, ConfigMixin):
"""
project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP).
For more details, see the original paper:
"""
@register_to_config
def __init__(
self,
*,
param_names: Tuple[str] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = (
(256, 93),
(256, 256),
(256, 256),
(256, 256),
),
d_latent: int = 1024,
):
super().__init__()
# check inputs
if len(param_names) != len(param_shapes):
raise ValueError("Must provide same number of `param_names` as `param_shapes`")
self.projections = nn.ModuleDict({})
for k, (vectors, channels) in zip(param_names, param_shapes):
self.projections[_sanitize_name(k)] = ChannelsProj(
vectors=vectors,
channels=channels,
d_latent=d_latent,
)
def forward(self, x: torch.Tensor):
out = {}
start = 0
for k, shape in zip(self.config.param_names, self.config.param_shapes):
vectors, _ = shape
end = start + vectors
x_bvd = x[:, start:end]
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
start = end
return out
class ShapERenderer(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
*,
param_names: Tuple[str] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = (
(256, 93),
(256, 256),
(256, 256),
(256, 256),
),
d_latent: int = 1024,
d_hidden: int = 256,
n_output: int = 12,
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
background: Tuple[float] = (
255.0,
255.0,
255.0,
),
):
super().__init__()
self.params_proj = ShapEParamsProjModel(
param_names=param_names,
param_shapes=param_shapes,
d_latent=d_latent,
)
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
@torch.no_grad()
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
"""
Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below
with some abuse of notations)
C(r) := sum(
transmittance(t[i]) * integrate(
lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
) for i in range(len(parts))
) + transmittance(t[-1]) * void_model(t[-1]).channels
where
1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through
the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are
obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t
where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the
shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and
transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
args:
rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples:
number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including
:return: A tuple of
- `channels`
- A importance samplers for additional fine-grained rendering
- raw model output
"""
origin, direction = rays[..., 0, :], rays[..., 1, :]
# Integrate over [t[i], t[i + 1]]
# 1 Intersect the rays with the current volume and sample ts to integrate along.
vrange = self.volume.intersect(origin, direction, t0_lower=None)
ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
ts = ts.to(rays.dtype)
if prev_model_out is not None:
# Append the previous ts now before fprop because previous
# rendering used a different model and we can't reuse the output.
ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values
batch_size, *_shape, _t0_dim = vrange.t0.shape
_, *ts_shape, _ts_dim = ts.shape
# 2. Get the points along the ray and query the model
directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
positions = origin.unsqueeze(-2) + ts * directions
directions = directions.to(self.mlp.dtype)
positions = positions.to(self.mlp.dtype)
optional_directions = directions if render_with_direction else None
model_out = self.mlp(
position=positions,
direction=optional_directions,
ts=ts,
nerf_level="coarse" if prev_model_out is None else "fine",
)
# 3. Integrate the model results
channels, weights, transmittance = integrate_samples(
vrange, model_out.ts, model_out.density, model_out.channels
)
# 4. Clean up results that do not intersect with the volume.
transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
# 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
channels = channels + transmittance * self.void(origin)
weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)
return channels, weighted_sampler, model_out
@torch.no_grad()
def decode(
self,
latents,
device,
size: int = 64,
ray_batch_size: int = 4096,
n_coarse_samples=64,
n_fine_samples=128,
):
# project the the paramters from the generated latents
projected_params = self.params_proj(latents)
# update the mlp layers of the renderer
for name, param in self.mlp.state_dict().items():
if f"nerstf.{name}" in projected_params.keys():
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
# create cameras object
camera = create_pan_cameras(size)
rays = camera.camera_rays
rays = rays.to(device)
n_batches = rays.shape[1] // ray_batch_size
coarse_sampler = StratifiedRaySampler()
images = []
for idx in range(n_batches):
rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
# render rays with coarse, stratified samples.
_, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
# Then, render with additional importance-weighted ray samples.
channels, _, _ = self.render_rays(
rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
)
images.append(channels)
images = torch.cat(images, dim=1)
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
return images

View File

@@ -47,7 +47,11 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -46,7 +46,11 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -47,7 +47,11 @@ class DDIMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -44,7 +44,11 @@ class DDPMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -46,7 +46,11 @@ class DDPMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -76,7 +77,11 @@ class BrownianTreeNoiseSampler:
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -89,19 +94,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@@ -190,10 +206,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
indices = (schedule_timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
pos = 0
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
@@ -292,6 +314,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
self.mid_point_sigma = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
def _second_order_timesteps(self, sigmas, log_sigmas):
def sigma_fn(_t):
return np.exp(-_t)
@@ -373,6 +399,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
step_index = self.index_for_timestep(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
# Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()

View File

@@ -29,7 +29,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -42,19 +46,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -47,7 +47,11 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -47,7 +47,11 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@@ -74,6 +90,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf).
clip_sample (`bool`, default `True`):
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
@@ -100,6 +120,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
clip_sample: Optional[bool] = False,
clip_sample_range: float = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
@@ -114,7 +136,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
elif beta_schedule == "exp":
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -131,10 +155,16 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
indices = (schedule_timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
pos = 0
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
@@ -207,7 +237,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.use_karras_sigmas:
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
@@ -228,6 +258,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.prev_derivative = None
self.dt = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
@@ -292,6 +326,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
step_index = self.index_for_timestep(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
@@ -316,12 +354,17 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample / (sigma_input**2 + 1)
)
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
if self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.state_in_first_order:
# 2. Convert to an ODE derivative for 1st order
derivative = (sample - pred_original_sample) / sigma_hat

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -24,7 +25,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -37,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@@ -130,10 +146,16 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
indices = (schedule_timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
pos = 0
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
@@ -245,6 +267,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
@@ -295,6 +321,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
step_index = self.index_for_timestep(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index]

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@@ -129,10 +145,16 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
indices = (schedule_timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
pos = 0
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
@@ -234,6 +256,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
@@ -283,6 +309,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
step_index = self.index_for_timestep(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index + 1]

View File

@@ -45,7 +45,11 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -58,19 +62,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -25,7 +25,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -38,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -43,7 +43,11 @@ class RePaintSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -56,19 +60,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -44,7 +44,11 @@ class UnCLIPSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

View File

@@ -104,7 +104,7 @@ if is_torch_available():
)
from .torch_utils import maybe_allow_in_graph
from .testing_utils import export_to_video
from .testing_utils import export_to_gif, export_to_video
logger = get_logger(__name__)

View File

@@ -377,6 +377,36 @@ class SemanticStableDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ShapEImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ShapEPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -300,6 +300,21 @@ def preprocess_image(image: PIL.Image, batch_size: int):
return 2.0 * image - 1.0
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
if output_gif_path is None:
output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
image[0].save(
output_gif_path,
save_all=True,
append_images=image[1:],
optimize=False,
duration=100,
loop=0,
)
return output_gif_path
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
if is_opencv_available():
import cv2

View File

View File

@@ -0,0 +1,265 @@
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline
from diffusers.pipelines.shap_e import ShapERenderer
from diffusers.utils import load_numpy, slow
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = ShapEPipeline
params = ["prompt"]
batch_params = ["prompt"]
required_optional_params = [
"num_images_per_prompt",
"num_inference_steps",
"generator",
"latents",
"guidance_scale",
"frame_size",
"output_type",
"return_dict",
]
test_xformers_attention = False
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def renderer_dim(self):
return 8
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=self.text_embedder_hidden_size,
projection_dim=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config)
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"num_attention_heads": 2,
"attention_head_dim": 16,
"embedding_dim": self.time_input_dim,
"num_embeddings": 32,
"embedding_proj_dim": self.text_embedder_hidden_size,
"time_embed_dim": self.time_embed_dim,
"num_layers": 1,
"clip_embed_dim": self.time_input_dim * 2,
"additional_embeddings": 0,
"time_embed_act_fn": "gelu",
"norm_in_type": "layer",
"encoder_hid_proj_type": None,
"added_emb_type": None,
}
model = PriorTransformer(**model_kwargs)
return model
@property
def dummy_renderer(self):
torch.manual_seed(0)
model_kwargs = {
"param_shapes": (
(self.renderer_dim, 93),
(self.renderer_dim, 8),
(self.renderer_dim, 8),
(self.renderer_dim, 8),
),
"d_latent": self.time_input_dim,
"d_hidden": self.renderer_dim,
"n_output": 12,
"background": (
0.1,
0.1,
0.1,
),
}
model = ShapERenderer(**model_kwargs)
return model
def get_dummy_components(self):
prior = self.dummy_prior
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
renderer = self.dummy_renderer
scheduler = HeunDiscreteScheduler(
beta_schedule="exp",
num_train_timesteps=1024,
prediction_type="sample",
use_karras_sigmas=True,
clip_sample=True,
clip_sample_range=1.0,
)
components = {
"prior": prior,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"renderer": renderer,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
"output_type": "np",
}
return inputs
def test_shap_e(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (20, 32, 32, 3)
expected_slice = np.array(
[
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
self._test_inference_batch_consistent(batch_sizes=[1, 2])
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
self._test_inference_batch_single_identical(
batch_size=2,
test_max_difference=test_max_difference,
relax_max_difference=relax_max_difference,
)
def test_num_images_per_prompt(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_size = 1
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
assert images.shape[0] == batch_size * num_images_per_prompt
@slow
@require_torch_gpu
class ShapEPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_shap_e(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/shap_e/test_shap_e_np_out.npy"
)
pipe = ShapEPipeline.from_pretrained("openai/shap-e")
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
images = pipe(
"a shark",
generator=generator,
guidance_scale=15.0,
num_inference_steps=64,
frame_size=64,
output_type="np",
).images[0]
assert images.shape == (20, 64, 64, 3)
assert_mean_pixel_difference(images, expected_image)

View File

@@ -0,0 +1,281 @@
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline
from diffusers.pipelines.shap_e import ShapERenderer
from diffusers.utils import floats_tensor, load_image, load_numpy, slow
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = ShapEImg2ImgPipeline
params = ["image"]
batch_params = ["image"]
required_optional_params = [
"num_images_per_prompt",
"num_inference_steps",
"generator",
"latents",
"guidance_scale",
"frame_size",
"output_type",
"return_dict",
]
test_xformers_attention = False
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def renderer_dim(self):
return 8
@property
def dummy_image_encoder(self):
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=self.text_embedder_hidden_size,
image_size=64,
projection_dim=self.text_embedder_hidden_size,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=1,
)
model = CLIPVisionModel(config)
return model
@property
def dummy_image_processor(self):
image_processor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
return image_processor
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"num_attention_heads": 2,
"attention_head_dim": 16,
"embedding_dim": self.time_input_dim,
"num_embeddings": 32,
"embedding_proj_dim": self.text_embedder_hidden_size,
"time_embed_dim": self.time_embed_dim,
"num_layers": 1,
"clip_embed_dim": self.time_input_dim * 2,
"additional_embeddings": 0,
"time_embed_act_fn": "gelu",
"norm_in_type": "layer",
"embedding_proj_norm_type": "layer",
"encoder_hid_proj_type": None,
"added_emb_type": None,
}
model = PriorTransformer(**model_kwargs)
return model
@property
def dummy_renderer(self):
torch.manual_seed(0)
model_kwargs = {
"param_shapes": (
(self.renderer_dim, 93),
(self.renderer_dim, 8),
(self.renderer_dim, 8),
(self.renderer_dim, 8),
),
"d_latent": self.time_input_dim,
"d_hidden": self.renderer_dim,
"n_output": 12,
"background": (
0.1,
0.1,
0.1,
),
}
model = ShapERenderer(**model_kwargs)
return model
def get_dummy_components(self):
prior = self.dummy_prior
image_encoder = self.dummy_image_encoder
image_processor = self.dummy_image_processor
renderer = self.dummy_renderer
scheduler = HeunDiscreteScheduler(
beta_schedule="exp",
num_train_timesteps=1024,
prediction_type="sample",
use_karras_sigmas=True,
clip_sample=True,
clip_sample_range=1.0,
)
components = {
"prior": prior,
"image_encoder": image_encoder,
"image_processor": image_processor,
"renderer": renderer,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
input_image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"image": input_image,
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
"output_type": "np",
}
return inputs
def test_shap_e(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (20, 32, 32, 3)
expected_slice = np.array(
[
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
0.00039216,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
self._test_inference_batch_consistent(batch_sizes=[1, 2])
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
self._test_inference_batch_single_identical(
batch_size=2,
test_max_difference=test_max_difference,
relax_max_difference=relax_max_difference,
)
def test_num_images_per_prompt(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_size = 1
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
assert images.shape[0] == batch_size * num_images_per_prompt
@slow
@require_torch_gpu
class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_shap_e_img2img(self):
input_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/shap_e/test_shap_e_img2img_out.npy"
)
pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img")
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
images = pipe(
input_image,
generator=generator,
guidance_scale=3.0,
num_inference_steps=64,
frame_size=64,
output_type="np",
).images[0]
assert images.shape == (20, 64, 64, 3)
assert_mean_pixel_difference(images, expected_image)

View File

@@ -30,11 +30,15 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "scaled_linear"]:
for schedule in ["linear", "scaled_linear", "exp"]:
self.check_over_configs(beta_schedule=schedule)
def test_clip_sample(self):
for clip_sample_range in [1.0, 2.0, 3.0]:
self.check_over_configs(clip_sample_range=clip_sample_range, clip_sample=True)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
for prediction_type in ["epsilon", "v_prediction", "sample"]:
self.check_over_configs(prediction_type=prediction_type)
def test_full_loop_no_noise(self):