Add Shap-E (#3742)
* refactor prior_transformer adding conversion script add pipeline add step_index from pipeline, + remove permute add zero pad token remove copy from statement for betas_for_alpha_bar function * add * add * update conversion script for renderer model * refactor camera a little bit * clean up * style * fix copies * Update src/diffusers/schedulers/scheduling_heun_discrete.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * alpha_transform_type * remove step_index argument * remove get_sigmas_karras * remove _yiyi_sigma_to_t * move the rescale prompt_embeds from prior_transformer to pipeline * replace baddbmm with einsum to match origial repo * Revert "replace baddbmm with einsum to match origial repo" This reverts commit3f6b435d65. * add step_index to scale_model_input * Revert "move the rescale prompt_embeds from prior_transformer to pipeline" This reverts commit5b5a8e6be9. * move rescale from prior_transformer to pipeline * correct step_index in scale_model_input * remove print lines * refactor prior - reduce arguments * make style * add prior_image * arg embedding_proj_norm -> norm_embedding_proj * add pre-norm for proj_embedding * move rescale prompt from pipeline to _encode_prompt * add img2img pipeline * style * copies * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py add arg: encoder_hid_proj Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py add new config: norm_in_type Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py add new config: added_emb_type Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py rename out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py rename config: out_dim -> clip_embed_dim Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/prior_transformer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * finish refactor prior_tranformer * make style * refactor renderer * fix * make style * refactor img2img * remove params_proj * add test * add upcast_softmax to prior_transformer * enable num_images_per_prompt, add save_gif utility * add * add fast test * make style * add slow test * style * add test for img2img * refactor * enable batching * style * refactor scheduler * update test * style * attempt to solve batch related tests timeout * add doc * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * hardcode rendering related 
config * update betas_for_alpha_bar on ddpm_scheduler * fix copies * fix * export_to_gif * style * second attempt to speed up batching tests * add doc page to index * Remove intermediate clipping * 3rd attempt to speed up batching tests * Remvoe time index * simplify scheduler * Fix more * Fix more * fix more * make style * fix schedulers * fix some more tests * finish * add one more test * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * style * apply feedbacks * style * fix copies * add one example * style * add example for img2img * fix doc * fix more doc strings * size -> frame_size * style * update doc * style * fix on doc * update repo name * improve the usage example in shap-e img2img * add usage examples in the shap-e docs. * consolidate examples. * minor fix. * update doc * Apply suggestions from code review * Apply suggestions from code review * remove upcast * Make sure background is white * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Apply suggestions from code review * Finish * Apply suggestions from code review * Update src/diffusers/pipelines/shap_e/pipeline_shap_e.py * Make style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@@ -226,6 +226,8 @@
      title: Self-Attention Guidance
    - local: api/pipelines/semantic_stable_diffusion
      title: Semantic Guidance
    - local: api/pipelines/shap_e
      title: Shap-E
    - local: api/pipelines/spectrogram_diffusion
      title: Spectrogram Diffusion
  - sections:
139 docs/source/en/api/pipelines/shap_e.mdx Normal file
@@ -0,0 +1,139 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Shap-E

## Overview

The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463) by Alex Nichol and Heewoo Jun from [OpenAI](https://github.com/openai).

The abstract from the paper is:

*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*

The original codebase can be found [here](https://github.com/openai/shap-e).
## Available Pipelines:

| Pipeline | Tasks |
|---|---|
| [pipeline_shap_e.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e.py) | *Text-to-3D Generation* |
| [pipeline_shap_e_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py) | *Image-to-3D Generation* |

## Available checkpoints

* [`openai/shap-e`](https://huggingface.co/openai/shap-e)
* [`openai/shap-e-img2img`](https://huggingface.co/openai/shap-e-img2img)
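Both checkpoints can be loaded with the generic [`DiffusionPipeline`] class, as in the examples below, or with the dedicated pipeline classes exported in this commit. A minimal sketch (`torch_dtype` is optional):

```python
import torch

from diffusers import ShapEImg2ImgPipeline, ShapEPipeline

# text-to-3D checkpoint
pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)

# image-to-3D checkpoint
img2img_pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
```
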
## Usage Examples

In the following, we will walk you through some examples of how to use the Shap-E pipelines to create 3D objects in gif format.

### Text-to-3D generation

We can use [`ShapEPipeline`] to create a 3D object from a text prompt. In this example, we will make a birthday cupcake for the :firecracker: diffusers library's first birthday. The workflow is the same as for any other text-to-image pipeline in diffusers.

```python
import torch

from diffusers import DiffusionPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

repo = "openai/shap-e"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
pipe = pipe.to(device)

guidance_scale = 15.0
prompt = ["A firecracker", "A birthday cupcake"]

images = pipe(
    prompt,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
).images
```

The output of [`ShapEPipeline`] is a list of lists of image frames. Each list of frames can be used to create a 3D object. Let's use the `export_to_gif` utility function in diffusers to turn the firecracker and the cupcake into 3D gifs!

```python
from diffusers.utils import export_to_gif

export_to_gif(images[0], "firecracker_3d.gif")
export_to_gif(images[1], "cake_3d.gif")
```

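This commit also enables `num_images_per_prompt`, so several renders can be produced for one prompt in a single call. A hedged sketch (the parameter name comes from the commit log; exact defaults and output layout may differ):

```python
# hypothetical sketch: two variations of the same prompt in one call
images = pipe(
    "A birthday cupcake",
    num_images_per_prompt=2,
    guidance_scale=15.0,
    num_inference_steps=64,
    frame_size=256,
).images

# each entry is a list of frames for one generated object
export_to_gif(images[0], "cupcake_variant_0.gif")
export_to_gif(images[1], "cupcake_variant_1.gif")
```
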
### Image-to-Image generation

You can use [`ShapEImg2ImgPipeline`] together with other text-to-image pipelines in diffusers to turn a 2D generation into 3D.

In this example, we will first generate a cheeseburger with the simple prompt "A cheeseburger, white background":

```python
from diffusers import DiffusionPipeline
import torch

pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
pipe_prior.to("cuda")

t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
t2i_pipe.to("cuda")

prompt = "A cheeseburger, white background"

image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
image = t2i_pipe(
    prompt,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
).images[0]

image.save("burger.png")
```

We will then use the Shap-E image-to-image pipeline to turn it into a 3D cheeseburger :)

```python
from PIL import Image
from diffusers.utils import export_to_gif

repo = "openai/shap-e-img2img"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

guidance_scale = 3.0
image = Image.open("burger.png").resize((256, 256))

images = pipe(
    image,
    guidance_scale=guidance_scale,
    num_inference_steps=64,
    frame_size=256,
).images

gif_path = export_to_gif(images[0], "burger_3d.gif")
```

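Each entry in `images` is a list of PIL image frames, so the rendered views can also be saved individually instead of being bundled into a gif. A small sketch:

```python
# save every rendered view of the first (and only) output as a separate PNG
for i, frame in enumerate(images[0]):
    frame.save(f"burger_frame_{i:02d}.png")
```
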
## ShapEPipeline
[[autodoc]] ShapEPipeline
    - all
    - __call__

## ShapEImg2ImgPipeline
[[autodoc]] ShapEImg2ImgPipeline
    - all
    - __call__
594 scripts/convert_shap_e_to_diffusers.py Normal file
@@ -0,0 +1,594 @@
import argparse
import tempfile

import torch
from accelerate import load_checkpoint_and_dispatch

from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapERenderer


"""
Example - From the diffusers root directory:

Download weights:
```sh
$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
```

Convert the model:
```sh
$ python scripts/convert_shap_e_to_diffusers.py \
    --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
    --prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
    --transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt \
    --dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer \
    --debug renderer
```
"""
|
||||
|
||||
|
||||
# prior
|
||||
|
||||
PRIOR_ORIGINAL_PREFIX = "wrapped"
|
||||
|
||||
PRIOR_CONFIG = {
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 1024 // 16,
|
||||
"num_layers": 24,
|
||||
"embedding_dim": 1024,
|
||||
"num_embeddings": 1024,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
"time_embed_dim": 1024 * 4,
|
||||
"embedding_proj_dim": 768,
|
||||
"clip_embed_dim": 1024 * 2,
|
||||
}
|
||||
|
||||
|
||||
def prior_model_from_original_config():
|
||||
model = PriorTransformer(**PRIOR_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
|
||||
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
|
||||
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.input_proj -> <diffusers>.proj_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"],
|
||||
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.clip_emb -> <diffusers>.embedding_proj
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"],
|
||||
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.pos_emb -> <diffusers>.positional_embedding
|
||||
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]})
|
||||
|
||||
# <original>.ln_pre -> <diffusers>.norm_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"],
|
||||
"norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
||||
for idx in range(len(model.transformer_blocks)):
|
||||
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
||||
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
|
||||
|
||||
# <original>.attn -> <diffusers>.attn1
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
||||
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
||||
diffusers_checkpoint.update(
|
||||
prior_attention_to_diffusers(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
original_attention_prefix=original_attention_prefix,
|
||||
attention_head_dim=model.attention_head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.mlp -> <diffusers>.ff
|
||||
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
||||
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
||||
diffusers_checkpoint.update(
|
||||
prior_ff_to_diffusers(
|
||||
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.ln_1 -> <diffusers>.norm1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_1.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_2 -> <diffusers>.norm3
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_2.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_post -> <diffusers>.norm_out
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"],
|
||||
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"],
|
||||
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def prior_attention_to_diffusers(
|
||||
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
|
||||
):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
|
||||
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
||||
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
|
||||
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
|
||||
split=3,
|
||||
chunk_size=attention_head_dim,
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
||||
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
||||
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
||||
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
||||
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
||||
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.c_proj -> <diffusers>.to_out.0
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
|
||||
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
|
||||
diffusers_checkpoint = {
|
||||
# <original>.c_fc -> <diffusers>.net.0.proj
|
||||
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
|
||||
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
|
||||
# <original>.c_proj -> <diffusers>.net.2
|
||||
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
|
||||
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
|
||||
}
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done prior
|
||||
|
||||
|
||||
# prior_image (only slightly different from prior)
|
||||
|
||||
|
||||
PRIOR_IMAGE_ORIGINAL_PREFIX = "wrapped"
|
||||
|
||||
# Uses default arguments
|
||||
PRIOR_IMAGE_CONFIG = {
|
||||
"num_attention_heads": 8,
|
||||
"attention_head_dim": 1024 // 8,
|
||||
"num_layers": 24,
|
||||
"embedding_dim": 1024,
|
||||
"num_embeddings": 1024,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"embedding_proj_norm_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
"time_embed_dim": 1024 * 4,
|
||||
"embedding_proj_dim": 1024,
|
||||
"clip_embed_dim": 1024 * 2,
|
||||
}
|
||||
|
||||
|
||||
def prior_image_model_from_original_config():
|
||||
model = PriorTransformer(**PRIOR_IMAGE_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# <original>.time_embed.c_fc -> <diffusers>.time_embedding.linear_1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.weight"],
|
||||
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.time_embed.c_proj -> <diffusers>.time_embedding.linear_2
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.weight"],
|
||||
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.input_proj -> <diffusers>.proj_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.weight"],
|
||||
"proj_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.clip_embed.0 -> <diffusers>.embedding_proj_norm
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj_norm.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.weight"],
|
||||
"embedding_proj_norm.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>..clip_embed.1 -> <diffusers>.embedding_proj
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"embedding_proj.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.weight"],
|
||||
"embedding_proj.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.pos_emb -> <diffusers>.positional_embedding
|
||||
diffusers_checkpoint.update(
|
||||
{"positional_embedding": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.pos_emb"][None, :]}
|
||||
)
|
||||
|
||||
# <original>.ln_pre -> <diffusers>.norm_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.weight"],
|
||||
"norm_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.backbone.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
||||
for idx in range(len(model.transformer_blocks)):
|
||||
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
||||
original_transformer_prefix = f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.backbone.resblocks.{idx}"
|
||||
|
||||
# <original>.attn -> <diffusers>.attn1
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
||||
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
||||
diffusers_checkpoint.update(
|
||||
prior_attention_to_diffusers(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
original_attention_prefix=original_attention_prefix,
|
||||
attention_head_dim=model.attention_head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.mlp -> <diffusers>.ff
|
||||
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
||||
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
||||
diffusers_checkpoint.update(
|
||||
prior_ff_to_diffusers(
|
||||
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# <original>.ln_1 -> <diffusers>.norm1
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_1.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_2 -> <diffusers>.norm3
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
||||
f"{original_transformer_prefix}.ln_2.weight"
|
||||
],
|
||||
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.ln_post -> <diffusers>.norm_out
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"norm_out.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.weight"],
|
||||
"norm_out.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# <original>.output_proj -> <diffusers>.proj_to_clip_embeddings
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.weight"],
|
||||
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done prior_image
|
||||
|
||||
|
||||
# renderer
|
||||
|
||||
RENDERER_CONFIG = {}
|
||||
|
||||
|
||||
def renderer_model_from_original_config():
|
||||
model = ShapERenderer(**RENDERER_CONFIG)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
RENDERER_MLP_ORIGINAL_PREFIX = "renderer.nerstf"
|
||||
|
||||
RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"
|
||||
|
||||
|
||||
def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
diffusers_checkpoint.update(
|
||||
{f"mlp.{k}": checkpoint[f"{RENDERER_MLP_ORIGINAL_PREFIX}.{k}"] for k in model.mlp.state_dict().keys()}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"params_proj.{k}": checkpoint[f"{RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"]
|
||||
for k in model.params_proj.state_dict().keys()
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
# done renderer
|
||||
|
||||
|
||||
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
|
||||
def split_attentions(*, weight, bias, split, chunk_size):
|
||||
weights = [None] * split
|
||||
biases = [None] * split
|
||||
|
||||
weights_biases_idx = 0
|
||||
|
||||
for starting_row_index in range(0, weight.shape[0], chunk_size):
|
||||
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
|
||||
|
||||
weight_rows = weight[row_indices, :]
|
||||
bias_rows = bias[row_indices]
|
||||
|
||||
if weights[weights_biases_idx] is None:
|
||||
assert weights[weights_biases_idx] is None
|
||||
weights[weights_biases_idx] = weight_rows
|
||||
biases[weights_biases_idx] = bias_rows
|
||||
else:
|
||||
assert weights[weights_biases_idx] is not None
|
||||
weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
|
||||
biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
|
||||
|
||||
weights_biases_idx = (weights_biases_idx + 1) % split
|
||||
|
||||
return weights, biases
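The TODO above notes that the split could be done more efficiently by building the indices once per target instead of concatenating chunk by chunk. A minimal vectorized sketch under the same assumption of interleaved q/k/v blocks of `chunk_size` rows (the helper name is hypothetical and not part of this script):

```python
def split_attentions_vectorized(*, weight, bias, split, chunk_size):
    # rows cycle through the `split` targets in consecutive blocks of `chunk_size` rows,
    # so reshape into (cycles, split, chunk_size, ...) and slice once per target
    n_cycles = weight.shape[0] // (split * chunk_size)
    w = weight.reshape(n_cycles, split, chunk_size, weight.shape[-1])
    b = bias.reshape(n_cycles, split, chunk_size)
    weights = [w[:, i].reshape(-1, weight.shape[-1]) for i in range(split)]
    biases = [b[:, i].reshape(-1) for i in range(split)]
    return weights, biases
```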
|
||||
|
||||
|
||||
# done unet utils
|
||||
|
||||
|
||||
# Driver functions
|
||||
|
||||
|
||||
def prior(*, args, checkpoint_map_location):
|
||||
print("loading prior")
|
||||
|
||||
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
prior_model = prior_model_from_original_config()
|
||||
|
||||
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
|
||||
|
||||
del prior_checkpoint
|
||||
|
||||
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
|
||||
|
||||
print("done loading prior")
|
||||
|
||||
return prior_model
|
||||
|
||||
|
||||
def prior_image(*, args, checkpoint_map_location):
|
||||
print("loading prior_image")
|
||||
|
||||
print(f"load checkpoint from {args.prior_image_checkpoint_path}")
|
||||
prior_checkpoint = torch.load(args.prior_image_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
prior_model = prior_image_model_from_original_config()
|
||||
|
||||
prior_diffusers_checkpoint = prior_image_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint)
|
||||
|
||||
del prior_checkpoint
|
||||
|
||||
load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model)
|
||||
|
||||
print("done loading prior_image")
|
||||
|
||||
return prior_model
|
||||
|
||||
|
||||
def renderer(*, args, checkpoint_map_location):
|
||||
print(" loading renderer")
|
||||
|
||||
renderer_checkpoint = torch.load(args.transmitter_checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
renderer_model = renderer_model_from_original_config()
|
||||
|
||||
renderer_diffusers_checkpoint = renderer_model_original_checkpoint_to_diffusers_checkpoint(
|
||||
renderer_model, renderer_checkpoint
|
||||
)
|
||||
|
||||
del renderer_checkpoint
|
||||
|
||||
load_checkpoint_to_model(renderer_diffusers_checkpoint, renderer_model, strict=True)
|
||||
|
||||
print("done loading renderer")
|
||||
|
||||
return renderer_model
|
||||
|
||||
|
||||
# prior model will expect clip_mean and clip_std, which are missing from the state_dict
|
||||
PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]
|
||||
|
||||
|
||||
def load_prior_checkpoint_to_model(checkpoint, model):
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
missing_keys, unexpected_keys = model.load_state_dict(torch.load(file.name), strict=False)
|
||||
missing_keys = list(set(missing_keys) - set(PRIOR_EXPECTED_MISSING_KEYS))
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
raise ValueError(f"Unexpected keys when loading prior model: {unexpected_keys}")
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f"Missing keys when loading prior model: {missing_keys}")
|
||||
|
||||
|
||||
def load_checkpoint_to_model(checkpoint, model, strict=False):
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
torch.save(checkpoint, file.name)
|
||||
del checkpoint
|
||||
if strict:
|
||||
model.load_state_dict(torch.load(file.name), strict=True)
|
||||
else:
|
||||
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--prior_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the prior checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prior_image_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the prior_image checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transmitter_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the transmitter checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_load_device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading checkpoints.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Only run a specific stage of the convert script. Used for debugging",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
||||
|
||||
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
||||
|
||||
if args.debug is not None:
|
||||
print(f"debug: only executing {args.debug}")
|
||||
|
||||
if args.debug is None:
|
||||
print("YiYi TO-DO")
|
||||
elif args.debug == "prior":
|
||||
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
prior_model.save_pretrained(args.dump_path)
|
||||
elif args.debug == "prior_image":
|
||||
prior_model = prior_image(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
prior_model.save_pretrained(args.dump_path)
|
||||
elif args.debug == "renderer":
|
||||
renderer_model = renderer(args=args, checkpoint_map_location=checkpoint_map_location)
|
||||
renderer_model.save_pretrained(args.dump_path)
|
||||
else:
|
||||
raise ValueError(f"unknown debug value : {args.debug}")
|
||||
@@ -149,6 +149,8 @@ else:
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        SemanticStableDiffusionPipeline,
        ShapEImg2ImgPipeline,
        ShapEPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
@@ -34,14 +34,33 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension.
|
||||
num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the
|
||||
length of the prompt after it has been tokenized).
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -54,6 +73,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -61,17 +88,41 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -87,8 +138,16 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
@@ -97,8 +156,8 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
@@ -172,7 +231,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
@@ -217,23 +276,61 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
encoder_hidden_states,
|
||||
proj_embeddings[:, None, :],
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states[:, None, :],
|
||||
prd_embedding,
|
||||
],
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -242,11 +339,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states[:, -1]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -77,6 +77,7 @@ else:
    from .latent_diffusion import LDMTextToImagePipeline
    from .paint_by_example import PaintByExamplePipeline
    from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
    from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
    from .stable_diffusion import (
        CycleDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
27 src/diffusers/pipelines/shap_e/__init__.py Normal file
@@ -0,0 +1,27 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
    from .camera import create_pan_cameras
    from .pipeline_shap_e import ShapEPipeline
    from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline
    from .renderer import (
        BoundingBoxVolume,
        ImportanceRaySampler,
        MLPNeRFModelOutput,
        MLPNeRSTFModel,
        ShapEParamsProjModel,
        ShapERenderer,
        StratifiedRaySampler,
        VoidNeRFModel,
    )
147 src/diffusers/pipelines/shap_e/camera.py Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DifferentiableProjectiveCamera:
|
||||
"""
|
||||
Implements a batch, differentiable, standard pinhole camera
|
||||
"""
|
||||
|
||||
origin: torch.Tensor # [batch_size x 3]
|
||||
x: torch.Tensor # [batch_size x 3]
|
||||
y: torch.Tensor # [batch_size x 3]
|
||||
z: torch.Tensor # [batch_size x 3]
|
||||
width: int
|
||||
height: int
|
||||
x_fov: float
|
||||
y_fov: float
|
||||
shape: Tuple[int]
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
|
||||
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
|
||||
assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
|
||||
|
||||
def resolution(self):
|
||||
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
|
||||
|
||||
def fov(self):
|
||||
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
|
||||
|
||||
def get_image_coords(self) -> torch.Tensor:
|
||||
"""
|
||||
:return: coords of shape (width * height, 2)
|
||||
"""
|
||||
pixel_indices = torch.arange(self.height * self.width)
|
||||
coords = torch.stack(
|
||||
[
|
||||
pixel_indices % self.width,
|
||||
torch.div(pixel_indices, self.width, rounding_mode="trunc"),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
return coords
|
||||
|
||||
@property
|
||||
def camera_rays(self):
|
||||
batch_size, *inner_shape = self.shape
|
||||
inner_batch_size = int(np.prod(inner_shape))
|
||||
|
||||
coords = self.get_image_coords()
|
||||
coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
|
||||
rays = self.get_camera_rays(coords)
|
||||
|
||||
rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
|
||||
|
||||
return rays
|
||||
|
||||
def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, *shape, n_coords = coords.shape
|
||||
assert n_coords == 2
|
||||
assert batch_size == self.origin.shape[0]
|
||||
|
||||
flat = coords.view(batch_size, -1, 2)
|
||||
|
||||
res = self.resolution()
|
||||
fov = self.fov()
|
||||
|
||||
fracs = (flat.float() / (res - 1)) * 2 - 1
|
||||
fracs = fracs * torch.tan(fov / 2)
|
||||
|
||||
fracs = fracs.view(batch_size, -1, 2)
|
||||
directions = (
|
||||
self.z.view(batch_size, 1, 3)
|
||||
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
|
||||
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
|
||||
)
|
||||
directions = directions / directions.norm(dim=-1, keepdim=True)
|
||||
rays = torch.stack(
|
||||
[
|
||||
torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
|
||||
directions,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
return rays.view(batch_size, *shape, 2, 3)
|
||||
|
||||
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
|
||||
"""
|
||||
Creates a new camera for the resized view assuming the aspect ratio does not change.
|
||||
"""
|
||||
assert width * self.height == height * self.width, "The aspect ratio should not change."
|
||||
return DifferentiableProjectiveCamera(
|
||||
origin=self.origin,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
z=self.z,
|
||||
width=width,
|
||||
height=height,
|
||||
x_fov=self.x_fov,
|
||||
y_fov=self.y_fov,
|
||||
)
|
||||
|
||||
|
||||
def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
|
||||
origins = []
|
||||
xs = []
|
||||
ys = []
|
||||
zs = []
|
||||
for theta in np.linspace(0, 2 * np.pi, num=20):
|
||||
z = np.array([np.sin(theta), np.cos(theta), -0.5])
|
||||
z /= np.sqrt(np.sum(z**2))
|
||||
origin = -z * 4
|
||||
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
|
||||
y = np.cross(z, x)
|
||||
origins.append(origin)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
zs.append(z)
|
||||
return DifferentiableProjectiveCamera(
|
||||
origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
|
||||
x=torch.from_numpy(np.stack(xs, axis=0)).float(),
|
||||
y=torch.from_numpy(np.stack(ys, axis=0)).float(),
|
||||
z=torch.from_numpy(np.stack(zs, axis=0)).float(),
|
||||
width=size,
|
||||
height=size,
|
||||
x_fov=0.7,
|
||||
y_fov=0.7,
|
||||
shape=(1, len(xs)),
|
||||
)
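For reference, a small usage sketch of the pan-camera helper above (the pipeline's renderer consumes these cameras internally; the shape follows from `camera_rays`):

```python
from diffusers.pipelines.shap_e import create_pan_cameras

cameras = create_pan_cameras(64)  # 20 cameras panning around the object at 64x64 resolution
rays = cameras.camera_rays        # tensor of shape (1, 20 * 64 * 64, 2, 3): ray origins and directions
```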
|
||||
390 src/diffusers/pipelines/shap_e/pipeline_shap_e.py Normal file
@@ -0,0 +1,390 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
|
||||
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
>>> repo = "openai/shap-e"
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to(device)
|
||||
|
||||
>>> guidance_scale = 15.0
|
||||
>>> prompt = "a shark"
|
||||
|
||||
>>> images = pipe(
|
||||
... prompt,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=64,
|
||||
... frame_size=256,
|
||||
... ).images
|
||||
|
||||
>>> gif_path = export_to_gif(images[0], "shark_3d.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapEPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for ShapEPipeline.
|
||||
|
||||
Args:
|
||||
images (`torch.FloatTensor`)
|
||||
a list of images for 3D rendering
|
||||
"""
|
||||
|
||||
images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
|
||||
|
||||
|
||||
class ShapEPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [self.text_encoder, self.prior]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better because the `prior` is executed iteratively.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if getattr(self, "safety_checker", None) is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
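# A minimal usage sketch of the offloading added above (illustration only, not
# part of this diff; requires a CUDA device and `accelerate`):
#
#     import torch
#     from diffusers import DiffusionPipeline
#
#     pipe = DiffusionPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
#     pipe.enable_model_cpu_offload()  # keeps only the active sub-model on the GPU
#     images = pipe("a shark", guidance_scale=15.0, frame_size=256).images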
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.text_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
|
||||
|
||||
# YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
|
||||
self.tokenizer.pad_token_id = 0
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
prompt_embeds = text_encoder_output.text_embeds
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
# in Shap-E the prompt_embeds are normalized here and then rescaled later
|
||||
prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# Rescale the features to have unit variance
|
||||
prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
|
||||
|
||||
return prompt_embeds
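# Standalone sketch (not part of this diff) of the normalize-then-rescale trick
# used in `_encode_prompt` above, on a dummy tensor:
#
#     import math
#     import torch
#
#     emb = torch.randn(2, 768)
#     emb = emb / torch.linalg.norm(emb, dim=-1, keepdim=True)  # unit L2 norm per row
#     emb = math.sqrt(emb.shape[1]) * emb                       # mean of squares per row is now 1
#     assert torch.allclose(emb.pow(2).mean(dim=-1), torch.ones(2), atol=1e-5)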
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
frame_size (`int`, *optional*, defaults to 64):
The width and height of each image frame of the rendered 3D output.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or
`"latent"` to return the raw generated latents.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`ShapEPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# prior
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
num_embeddings = self.prior.config.num_embeddings
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_embeddings * embedding_dim),
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
|
||||
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = self.prior(
|
||||
scaled_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=prompt_embeds,
|
||||
).predicted_image_embedding
|
||||
|
||||
# remove the variance
|
||||
noise_pred, _ = noise_pred.split(
|
||||
scaled_model_input.shape[2], dim=2
|
||||
) # batch_size, num_embeddings, embedding_dim
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return ShapEPipelineOutput(images=images)
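# Sketch of consuming the output above (hypothetical file names; with
# output_type="pil", `images` holds one list of rendered frames per generated asset):
#
#     from diffusers.utils import export_to_gif
#
#     out = pipe("a shark", guidance_scale=15.0, num_inference_steps=64, frame_size=256)
#     for idx, frames in enumerate(out.images):
#         export_to_gif(frames, f"asset_{idx}.gif")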
|
||||
349
src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from PIL import Image
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from diffusers.utils import export_to_gif, load_image

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

>>> repo = "openai/shap-e-img2img"
>>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)

>>> guidance_scale = 3.0
>>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png"
>>> image = load_image(image_url).convert("RGB")

>>> images = pipe(
... image,
... guidance_scale=guidance_scale,
... num_inference_steps=64,
... frame_size=256,
... ).images

>>> gif_path = export_to_gif(images[0], "corgi_3d.gif")
```
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapEPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for ShapEPipeline.
|
||||
|
||||
Args:
|
||||
images (`torch.FloatTensor`)
|
||||
a list of images for 3D rendering
|
||||
"""
|
||||
|
||||
images: Union[PIL.Image.Image, np.ndarray]
|
||||
|
||||
|
||||
class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating the latent representation of a 3D asset from an image and rendering it with the NeRF method using Shap-E.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
prior ([`PriorTransformer`]):
The canonical unCLIP prior to approximate the image embedding from the CLIP image embedding.
image_encoder ([`CLIPVisionModel`]):
Frozen CLIP image-encoder.
image_processor ([`CLIPImageProcessor`]):
A `CLIPImageProcessor` to process images before encoding.
scheduler ([`HeunDiscreteScheduler`]):
A scheduler to be used in combination with `prior` to generate the image embedding.
renderer ([`ShapERenderer`]):
The Shap-E renderer projects the generated latents into parameters of an MLP that is used to create 3D objects
with the NeRF rendering method.
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [self.image_encoder, self.prior]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.image_encoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_image(
|
||||
self,
|
||||
image,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
if isinstance(image, List) and isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)
|
||||
|
||||
image = image.to(dtype=self.image_encoder.dtype, device=device)
|
||||
|
||||
image_embeds = self.image_encoder(image)["last_hidden_state"]
|
||||
image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, 256, dim
|
||||
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
return image_embeds
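# Standalone sketch of the image-conditioning path above (the checkpoint name is
# an assumption and `pil_image` is any PIL image; not part of this diff):
# preprocess with CLIPImageProcessor, encode with CLIPVisionModel, drop the CLS
# token, and use an all-zero embedding as the unconditional branch for
# classifier-free guidance.
#
#     import torch
#     from transformers import CLIPImageProcessor, CLIPVisionModel
#
#     processor = CLIPImageProcessor()
#     encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
#     pixels = processor(pil_image, return_tensors="pt").pixel_values
#     embeds = encoder(pixels).last_hidden_state[:, 1:, :]  # drop the CLS token
#     cfg_input = torch.cat([torch.zeros_like(embeds), embeds])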
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_images_per_prompt: int = 1,
|
||||
num_inference_steps: int = 25,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The image or images to guide the 3D generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
frame_size (`int`, *optional*, defaults to 64):
The width and height of each image frame of the rendered 3D output.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or
`"latent"` to return the raw generated latents.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`ShapEPipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, torch.Tensor):
|
||||
batch_size = image.shape[0]
|
||||
elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
|
||||
batch_size = len(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# prior
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
num_embeddings = self.prior.config.num_embeddings
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
|
||||
latents = self.prepare_latents(
|
||||
(batch_size, num_embeddings * embedding_dim),
|
||||
image_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
self.scheduler,
|
||||
)
|
||||
|
||||
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
|
||||
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = self.prior(
|
||||
scaled_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=image_embeds,
|
||||
).predicted_image_embedding
|
||||
|
||||
# remove the variance
|
||||
noise_pred, _ = noise_pred.split(
|
||||
scaled_model_input.shape[2], dim=2
|
||||
) # batch_size, num_embeddings, embedding_dim
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
timestep=t,
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (images,)
|
||||
|
||||
return ShapEPipelineOutput(images=images)
|
||||
709
src/diffusers/pipelines/shap_e/renderer.py
Normal file
@@ -0,0 +1,709 @@
|
||||
# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...utils import BaseOutput
|
||||
from .camera import create_pan_cameras
|
||||
|
||||
|
||||
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
|
||||
r"""
|
||||
Sample from the given discrete probability distribution with replacement.
|
||||
|
||||
The i-th bin is assumed to have mass pmf[i].
|
||||
|
||||
Args:
|
||||
pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all()
|
||||
n_samples: number of samples
|
||||
|
||||
Return:
|
||||
indices sampled with replacement
|
||||
"""
|
||||
|
||||
*shape, support_size, last_dim = pmf.shape
|
||||
assert last_dim == 1
|
||||
|
||||
cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
|
||||
inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
|
||||
|
||||
return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
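# Quick sanity-check sketch for `sample_pmf` (not part of this diff): draw 8
# indices from a 3-bin distribution for a single ray.
#
#     pmf = torch.tensor([0.2, 0.5, 0.3]).view(1, 3, 1)  # [batch, support_size, 1], sums to 1 over dim -2
#     inds = sample_pmf(pmf, n_samples=8)                 # [1, 8, 1], values in {0, 1, 2}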
|
||||
|
||||
|
||||
def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
|
||||
"""
|
||||
Concatenate x and its positional encodings, following NeRF.
|
||||
|
||||
Reference: https://arxiv.org/pdf/2210.04628.pdf
|
||||
"""
|
||||
if min_deg == max_deg:
|
||||
return x
|
||||
|
||||
scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
|
||||
*shape, dim = x.shape
|
||||
xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
|
||||
assert xb.shape[-1] == dim * (max_deg - min_deg)
|
||||
emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
|
||||
return torch.cat([x, emb], dim=-1)
|
||||
|
||||
|
||||
def encode_position(position):
|
||||
return posenc_nerf(position, min_deg=0, max_deg=15)
|
||||
|
||||
|
||||
def encode_direction(position, direction=None):
|
||||
if direction is None:
|
||||
return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
|
||||
else:
|
||||
return posenc_nerf(direction, min_deg=0, max_deg=8)
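# Dimensionality sketch (not part of this diff): for a 3-D input, the position
# encoding yields 3 + 3 * 2 * 15 = 93 features and the direction encoding
# 3 + 3 * 2 * 8 = 51 features; 93 is exactly the input width of the first
# NeRSTF MLP layer configured further below as (256, 93).
#
#     x = torch.zeros(1, 3)
#     assert encode_position(x).shape[-1] == 93
#     assert encode_direction(x).shape[-1] == 51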
|
||||
|
||||
|
||||
def _sanitize_name(x: str) -> str:
|
||||
return x.replace(".", "__")
|
||||
|
||||
|
||||
def integrate_samples(volume_range, ts, density, channels):
|
||||
r"""
|
||||
Function integrating the model output.
|
||||
|
||||
Args:
|
||||
volume_range: Specifies the integral range [t0, t1]
|
||||
ts: timesteps
|
||||
density: torch.Tensor [batch_size, *shape, n_samples, 1]
|
||||
channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
|
||||
returns:
|
||||
channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density
|
||||
*transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume
|
||||
)
|
||||
"""
|
||||
|
||||
# 1. Calculate the weights
|
||||
_, _, dt = volume_range.partition(ts)
|
||||
ddensity = density * dt
|
||||
|
||||
mass = torch.cumsum(ddensity, dim=-2)
|
||||
transmittance = torch.exp(-mass[..., -1, :])
|
||||
|
||||
alphas = 1.0 - torch.exp(-ddensity)
|
||||
Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
|
||||
# This is the probability of light hitting and reflecting off of
|
||||
# something at depth [..., i, :].
|
||||
weights = alphas * Ts
|
||||
|
||||
# 2. Integrate channels
|
||||
channels = torch.sum(channels * weights, dim=-2)
|
||||
|
||||
return channels, weights, transmittance
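# The weights computed above follow the standard discrete volume-rendering
# quadrature (sketch of the math, not part of this diff):
#
#     alpha_i = 1 - exp(-density_i * dt_i)           # opacity of segment i
#     T_i     = exp(-sum_{j < i} density_j * dt_j)   # transmittance up to segment i
#     w_i     = alpha_i * T_i                        # contribution of segment i
#     color   = sum_i w_i * channels_i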
|
||||
|
||||
|
||||
class VoidNeRFModel(nn.Module):
|
||||
"""
|
||||
Implements the default empty space model where all queries are rendered as background.
|
||||
"""
|
||||
|
||||
def __init__(self, background, channel_scale=255.0):
|
||||
super().__init__()
|
||||
background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
|
||||
|
||||
self.register_buffer("background", background)
|
||||
|
||||
def forward(self, position):
|
||||
background = self.background[None].to(position.device)
|
||||
|
||||
shape = position.shape[:-1]
|
||||
ones = [1] * (len(shape) - 1)
|
||||
n_channels = background.shape[-1]
|
||||
background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
|
||||
|
||||
return background
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolumeRange:
|
||||
t0: torch.Tensor
|
||||
t1: torch.Tensor
|
||||
intersected: torch.Tensor
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.t0.shape == self.t1.shape == self.intersected.shape
|
||||
|
||||
def partition(self, ts):
|
||||
"""
|
||||
Partitions t0 and t1 into n_samples intervals.
|
||||
|
||||
Args:
|
||||
ts: [batch_size, *shape, n_samples, 1]
|
||||
|
||||
Return:
|
||||
|
||||
lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
|
||||
*shape, n_samples, 1]
|
||||
|
||||
where
|
||||
ts \\in [lower, upper] deltas = upper - lower
|
||||
"""
|
||||
|
||||
mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
|
||||
lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
|
||||
upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
|
||||
delta = upper - lower
|
||||
assert lower.shape == upper.shape == delta.shape == ts.shape
|
||||
return lower, upper, delta
|
||||
|
||||
|
||||
class BoundingBoxVolume(nn.Module):
|
||||
"""
|
||||
Axis-aligned bounding box defined by the two opposite corners.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bbox_min,
|
||||
bbox_max,
|
||||
min_dist: float = 0.0,
|
||||
min_t_range: float = 1e-3,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
bbox_min: the left/bottommost corner of the bounding box
|
||||
bbox_max: the other corner of the bounding box
|
||||
min_dist: all rays should start at least this distance away from the origin.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.min_dist = min_dist
|
||||
self.min_t_range = min_t_range
|
||||
|
||||
self.bbox_min = torch.tensor(bbox_min)
|
||||
self.bbox_max = torch.tensor(bbox_max)
|
||||
self.bbox = torch.stack([self.bbox_min, self.bbox_max])
|
||||
assert self.bbox.shape == (2, 3)
|
||||
assert min_dist >= 0.0
|
||||
assert min_t_range > 0.0
|
||||
|
||||
def intersect(
|
||||
self,
|
||||
origin: torch.Tensor,
|
||||
direction: torch.Tensor,
|
||||
t0_lower: Optional[torch.Tensor] = None,
|
||||
epsilon=1e-6,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
origin: [batch_size, *shape, 3]
|
||||
direction: [batch_size, *shape, 3]
|
||||
t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
|
||||
params: Optional meta parameters in case Volume is parametric
|
||||
epsilon: to stabilize calculations
|
||||
|
||||
Return:
|
||||
A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
|
||||
the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
|
||||
be on the boundary of the volume.
|
||||
"""
|
||||
|
||||
batch_size, *shape, _ = origin.shape
|
||||
ones = [1] * len(shape)
|
||||
bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
|
||||
|
||||
def _safe_divide(a, b, epsilon=1e-6):
|
||||
return a / torch.where(b < 0, b - epsilon, b + epsilon)
|
||||
|
||||
ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
|
||||
|
||||
# Cases to think about:
|
||||
#
|
||||
# 1. t1 <= t0: the ray does not pass through the AABB.
|
||||
# 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
|
||||
# 3. t0 <= 0 <= t1: the ray starts from inside the BB
|
||||
# 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
|
||||
#
|
||||
# 1 and 4 are clearly handled from t0 < t1 below.
|
||||
# Making t0 at least min_dist (>= 0) takes care of 2 and 3.
|
||||
t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
|
||||
t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
|
||||
assert t0.shape == t1.shape == (batch_size, *shape, 1)
|
||||
if t0_lower is not None:
|
||||
assert t0.shape == t0_lower.shape
|
||||
t0 = torch.maximum(t0, t0_lower)
|
||||
|
||||
intersected = t0 + self.min_t_range < t1
|
||||
t0 = torch.where(intersected, t0, torch.zeros_like(t0))
|
||||
t1 = torch.where(intersected, t1, torch.ones_like(t1))
|
||||
|
||||
return VolumeRange(t0=t0, t1=t1, intersected=intersected)
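# Sketch of intersecting a single ray with the unit box used by the renderer
# (standalone example, not part of this diff):
#
#     volume = BoundingBoxVolume(bbox_min=[-1.0, -1.0, -1.0], bbox_max=[1.0, 1.0, 1.0])
#     origin = torch.tensor([[[0.0, 0.0, -2.0]]])    # [batch=1, *shape=1, 3]
#     direction = torch.tensor([[[0.0, 0.0, 1.0]]])  # pointing at the box
#     vrange = volume.intersect(origin, direction)
#     # vrange.t0 ~ 1.0, vrange.t1 ~ 3.0, vrange.intersected is True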
|
||||
|
||||
|
||||
class StratifiedRaySampler(nn.Module):
|
||||
"""
|
||||
Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
|
||||
"""
|
||||
|
||||
def __init__(self, depth_mode: str = "linear"):
|
||||
"""
|
||||
:param depth_mode: linear samples ts linearly in depth. harmonic ensures
|
||||
closer points are sampled more densely.
|
||||
"""
|
||||
self.depth_mode = depth_mode
|
||||
assert self.depth_mode in ("linear", "geometric", "harmonic")
|
||||
|
||||
def sample(
|
||||
self,
|
||||
t0: torch.Tensor,
|
||||
t1: torch.Tensor,
|
||||
n_samples: int,
|
||||
epsilon: float = 1e-3,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
t0: start time has shape [batch_size, *shape, 1]
|
||||
t1: finish time has shape [batch_size, *shape, 1]
|
||||
n_samples: number of ts to sample
|
||||
Return:
|
||||
sampled ts of shape [batch_size, *shape, n_samples, 1]
|
||||
"""
|
||||
ones = [1] * (len(t0.shape) - 1)
|
||||
ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
|
||||
|
||||
if self.depth_mode == "linear":
|
||||
ts = t0 * (1.0 - ts) + t1 * ts
|
||||
elif self.depth_mode == "geometric":
|
||||
ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
|
||||
elif self.depth_mode == "harmonic":
|
||||
# The original NeRF recommends this interpolation scheme for
|
||||
# spherical scenes, but there could be some weird edge cases when
|
||||
# the observer crosses from the inner to outer volume.
|
||||
ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
|
||||
|
||||
mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
|
||||
upper = torch.cat([mids, t1], dim=-1)
|
||||
lower = torch.cat([t0, mids], dim=-1)
|
||||
# yiyi notes: add a random seed here for testing, don't forget to remove
|
||||
torch.manual_seed(0)
|
||||
t_rand = torch.rand_like(ts)
|
||||
|
||||
ts = lower + (upper - lower) * t_rand
|
||||
return ts.unsqueeze(-1)
|
||||
|
||||
|
||||
class ImportanceRaySampler(nn.Module):
|
||||
"""
|
||||
Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
volume_range: VolumeRange,
|
||||
ts: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
blur_pool: bool = False,
|
||||
alpha: float = 1e-5,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
volume_range: the range in which a ray intersects the given volume.
|
||||
ts: earlier samples from the coarse rendering step
|
||||
weights: discretized version of density * transmittance
|
||||
blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
|
||||
alpha: small value to add to weights.
|
||||
"""
|
||||
self.volume_range = volume_range
|
||||
self.ts = ts.clone().detach()
|
||||
self.weights = weights.clone().detach()
|
||||
self.blur_pool = blur_pool
|
||||
self.alpha = alpha
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
t0: start time has shape [batch_size, *shape, 1]
|
||||
t1: finish time has shape [batch_size, *shape, 1]
|
||||
n_samples: number of ts to sample
|
||||
Return:
|
||||
sampled ts of shape [batch_size, *shape, n_samples, 1]
|
||||
"""
|
||||
lower, upper, _ = self.volume_range.partition(self.ts)
|
||||
|
||||
batch_size, *shape, n_coarse_samples, _ = self.ts.shape
|
||||
|
||||
weights = self.weights
|
||||
if self.blur_pool:
|
||||
padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
|
||||
maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
|
||||
weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
|
||||
weights = weights + self.alpha
|
||||
pmf = weights / weights.sum(dim=-2, keepdim=True)
|
||||
inds = sample_pmf(pmf, n_samples)
|
||||
assert inds.shape == (batch_size, *shape, n_samples, 1)
|
||||
assert (inds >= 0).all() and (inds < n_coarse_samples).all()
|
||||
|
||||
t_rand = torch.rand(inds.shape, device=inds.device)
|
||||
lower_ = torch.gather(lower, -2, inds)
|
||||
upper_ = torch.gather(upper, -2, inds)
|
||||
|
||||
ts = lower_ + (upper_ - lower_) * t_rand
|
||||
ts = torch.sort(ts, dim=-2).values
|
||||
return ts
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLPNeRFModelOutput(BaseOutput):
|
||||
density: torch.Tensor
|
||||
signed_distance: torch.Tensor
|
||||
channels: torch.Tensor
|
||||
ts: torch.Tensor
|
||||
|
||||
|
||||
class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int = 256,
|
||||
n_output: int = 12,
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Instantiate the MLP
|
||||
|
||||
# Find out the dimension of encoded position and direction
|
||||
dummy = torch.eye(1, 3)
|
||||
d_posenc_pos = encode_position(position=dummy).shape[-1]
|
||||
d_posenc_dir = encode_direction(position=dummy).shape[-1]
|
||||
|
||||
mlp_widths = [d_hidden] * n_hidden_layers
|
||||
input_widths = [d_posenc_pos] + mlp_widths
|
||||
output_widths = mlp_widths + [n_output]
|
||||
|
||||
if insert_direction_at is not None:
|
||||
input_widths[insert_direction_at] += d_posenc_dir
|
||||
|
||||
self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])
|
||||
|
||||
if act_fn == "swish":
|
||||
# self.activation = swish
|
||||
# yiyi testing:
|
||||
self.activation = lambda x: F.silu(x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function {act_fn}")
|
||||
|
||||
self.sdf_activation = torch.tanh
|
||||
self.density_activation = torch.nn.functional.relu
|
||||
self.channel_activation = torch.sigmoid
|
||||
|
||||
def map_indices_to_keys(self, output):
|
||||
h_map = {
|
||||
"sdf": (0, 1),
|
||||
"density_coarse": (1, 2),
|
||||
"density_fine": (2, 3),
|
||||
"stf": (3, 6),
|
||||
"nerf_coarse": (6, 9),
|
||||
"nerf_fine": (9, 12),
|
||||
}
|
||||
|
||||
mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}
|
||||
|
||||
return mapped_output
|
||||
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse"):
|
||||
h = encode_position(position)
|
||||
|
||||
h_preact = h
|
||||
h_directionless = None
|
||||
for i, layer in enumerate(self.mlp):
|
||||
if i == self.config.insert_direction_at: # 4 in the config
|
||||
h_directionless = h_preact
|
||||
h_direction = encode_direction(position, direction=direction)
|
||||
h = torch.cat([h, h_direction], dim=-1)
|
||||
|
||||
h = layer(h)
|
||||
|
||||
h_preact = h
|
||||
|
||||
if i < len(self.mlp) - 1:
|
||||
h = self.activation(h)
|
||||
|
||||
h_final = h
|
||||
if h_directionless is None:
|
||||
h_directionless = h_preact
|
||||
|
||||
activation = self.map_indices_to_keys(h_final)
|
||||
|
||||
if nerf_level == "coarse":
|
||||
h_density = activation["density_coarse"]
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_density = activation["density_fine"]
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
density = self.density_activation(h_density)
|
||||
signed_distance = self.sdf_activation(activation["sdf"])
|
||||
channels = self.channel_activation(h_channels)
|
||||
|
||||
# yiyi notes: I think signed_distance is not used
|
||||
return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
|
||||
|
||||
|
||||
class ChannelsProj(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vectors: int,
|
||||
channels: int,
|
||||
d_latent: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(d_latent, vectors * channels)
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
self.d_latent = d_latent
|
||||
self.vectors = vectors
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_bvd = x
|
||||
w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
|
||||
b_vc = self.proj.bias.view(1, self.vectors, self.channels)
|
||||
h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
|
||||
h = self.norm(h)
|
||||
|
||||
h = h + b_vc
|
||||
return h
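# What the einsum above computes (sketch, not part of this diff): latent vector v
# is projected by its own `channels x d_latent` slice of the shared nn.Linear
# weight, i.e. the diagonal blocks of the full projection, then LayerNorm, then bias.
#
#     proj = ChannelsProj(vectors=4, channels=8, d_latent=16)
#     x = torch.randn(2, 4, 16)                          # [batch, vectors, d_latent]
#     full = x @ proj.proj.weight.T                      # [2, 4, 4 * 8]
#     blocks = full.view(2, 4, 4, 8)
#     ref = blocks[:, torch.arange(4), torch.arange(4)]  # keep block v for vector v
#     expected = proj.norm(ref) + proj.proj.bias.view(1, 4, 8)
#     assert torch.allclose(proj(x), expected, atol=1e-5)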
|
||||
|
||||
|
||||
class ShapEParamsProjModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Project the latent representation of a 3D asset to obtain the weights of a multi-layer perceptron (MLP).

For more details, see the original paper: https://arxiv.org/abs/2305.02463
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
),
|
||||
d_latent: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# check inputs
|
||||
if len(param_names) != len(param_shapes):
|
||||
raise ValueError("Must provide same number of `param_names` as `param_shapes`")
|
||||
self.projections = nn.ModuleDict({})
|
||||
for k, (vectors, channels) in zip(param_names, param_shapes):
|
||||
self.projections[_sanitize_name(k)] = ChannelsProj(
|
||||
vectors=vectors,
|
||||
channels=channels,
|
||||
d_latent=d_latent,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
out = {}
|
||||
start = 0
|
||||
for k, shape in zip(self.config.param_names, self.config.param_shapes):
|
||||
vectors, _ = shape
|
||||
end = start + vectors
|
||||
x_bvd = x[:, start:end]
|
||||
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
|
||||
start = end
|
||||
return out
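# Shape sketch for the projection above (not part of this diff): a Shap-E latent
# of shape [batch, 1024, d_latent] is split row-wise into four blocks of 256
# vectors, and each block is projected into one MLP weight matrix.
#
#     model = ShapEParamsProjModel()
#     latent = torch.randn(1, 1024, 1024)  # [batch, sum of `vectors`, d_latent]
#     params = model(latent)
#     # params["nerstf.mlp.0.weight"].shape == (1, 256, 93)
#     # params["nerstf.mlp.1.weight"].shape == (1, 256, 256), and so on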
|
||||
|
||||
|
||||
class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
param_names: Tuple[str] = (
|
||||
"nerstf.mlp.0.weight",
|
||||
"nerstf.mlp.1.weight",
|
||||
"nerstf.mlp.2.weight",
|
||||
"nerstf.mlp.3.weight",
|
||||
),
|
||||
param_shapes: Tuple[Tuple[int]] = (
|
||||
(256, 93),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
(256, 256),
|
||||
),
|
||||
d_latent: int = 1024,
|
||||
d_hidden: int = 256,
|
||||
n_output: int = 12,
|
||||
n_hidden_layers: int = 6,
|
||||
act_fn: str = "swish",
|
||||
insert_direction_at: int = 4,
|
||||
background: Tuple[float] = (
|
||||
255.0,
|
||||
255.0,
|
||||
255.0,
|
||||
),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.params_proj = ShapEParamsProjModel(
|
||||
param_names=param_names,
|
||||
param_shapes=param_shapes,
|
||||
d_latent=d_latent,
|
||||
)
|
||||
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
|
||||
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
|
||||
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
|
||||
|
||||
@torch.no_grad()
|
||||
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
|
||||
"""
|
||||
Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below
|
||||
with some abuse of notations)
|
||||
|
||||
C(r) := sum(
|
||||
transmittance(t[i]) * integrate(
|
||||
lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
|
||||
) for i in range(len(parts))
|
||||
) + transmittance(t[-1]) * void_model(t[-1]).channels
|
||||
|
||||
where
|
||||
|
||||
1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through
|
||||
the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are
|
||||
obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t
|
||||
where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the
|
||||
shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and
|
||||
transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
|
||||
math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
|
||||
|
||||
args:
rays: [batch_size x ... x 2 x 3] origin and direction.
sampler: disjoint volume integrals.
n_samples: number of ts to sample.
prev_model_out: model output from the previous (coarse) rendering step; when given, its `ts` are merged in and the fine-level MLP heads are used.
|
||||
|
||||
:return: A tuple of
|
||||
- `channels`
|
||||
- An importance sampler for additional fine-grained rendering
|
||||
- raw model output
|
||||
"""
|
||||
origin, direction = rays[..., 0, :], rays[..., 1, :]
|
||||
|
||||
# Integrate over [t[i], t[i + 1]]
|
||||
|
||||
# 1 Intersect the rays with the current volume and sample ts to integrate along.
|
||||
vrange = self.volume.intersect(origin, direction, t0_lower=None)
|
||||
ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
|
||||
ts = ts.to(rays.dtype)
|
||||
|
||||
if prev_model_out is not None:
|
||||
# Append the previous ts now before fprop because previous
|
||||
# rendering used a different model and we can't reuse the output.
|
||||
ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values
|
||||
|
||||
batch_size, *_shape, _t0_dim = vrange.t0.shape
|
||||
_, *ts_shape, _ts_dim = ts.shape
|
||||
|
||||
# 2. Get the points along the ray and query the model
|
||||
directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
|
||||
positions = origin.unsqueeze(-2) + ts * directions
|
||||
|
||||
directions = directions.to(self.mlp.dtype)
|
||||
positions = positions.to(self.mlp.dtype)
|
||||
|
||||
optional_directions = directions if render_with_direction else None
|
||||
|
||||
model_out = self.mlp(
|
||||
position=positions,
|
||||
direction=optional_directions,
|
||||
ts=ts,
|
||||
nerf_level="coarse" if prev_model_out is None else "fine",
|
||||
)
|
||||
|
||||
# 3. Integrate the model results
|
||||
channels, weights, transmittance = integrate_samples(
|
||||
vrange, model_out.ts, model_out.density, model_out.channels
|
||||
)
|
||||
|
||||
# 4. Clean up results that do not intersect with the volume.
|
||||
transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
|
||||
channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
|
||||
# 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
|
||||
channels = channels + transmittance * self.void(origin)
|
||||
|
||||
weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)
|
||||
|
||||
return channels, weighted_sampler, model_out
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
latents,
|
||||
device,
|
||||
size: int = 64,
|
||||
ray_batch_size: int = 4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
):
|
||||
# project the parameters from the generated latents
|
||||
projected_params = self.params_proj(latents)
|
||||
|
||||
# update the mlp layers of the renderer
|
||||
for name, param in self.mlp.state_dict().items():
|
||||
if f"nerstf.{name}" in projected_params.keys():
|
||||
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
|
||||
|
||||
# create cameras object
|
||||
camera = create_pan_cameras(size)
|
||||
rays = camera.camera_rays
|
||||
rays = rays.to(device)
|
||||
n_batches = rays.shape[1] // ray_batch_size
|
||||
|
||||
coarse_sampler = StratifiedRaySampler()
|
||||
|
||||
images = []
|
||||
|
||||
for idx in range(n_batches):
|
||||
rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
|
||||
|
||||
# render rays with coarse, stratified samples.
|
||||
_, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
|
||||
# Then, render with additional importance-weighted ray samples.
|
||||
channels, _, _ = self.render_rays(
|
||||
rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
|
||||
)
|
||||
|
||||
images.append(channels)
|
||||
|
||||
images = torch.cat(images, dim=1)
|
||||
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
|
||||
|
||||
return images
|
||||
@@ -47,7 +47,11 @@ class DDIMSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
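# Hedged usage sketch of the new `alpha_transform_type` argument introduced in
# this hunk (illustration only, not part of the diff):
#
#     betas_cosine = betas_for_alpha_bar(1000)                           # default cosine alpha_bar
#     betas_exp = betas_for_alpha_bar(1000, alpha_transform_type="exp")  # exponential alpha_bar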
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,11 @@ class DDIMSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,11 @@ class DDIMParallelSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,11 @@ class DDPMSchedulerOutput(BaseOutput):
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,11 @@ class DDPMParallelSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -76,7 +77,11 @@ class BrownianTreeNoiseSampler:
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -89,19 +94,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -190,10 +206,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if self.state_in_first_order:
|
||||
pos = -1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
pos = 0
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
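A minimal standalone sketch of the duplicate-timestep lookup above (hypothetical names; the real scheduler reads self.timesteps and self._index_counter). Each timestep appears twice because these second-order samplers evaluate the model twice per step, and the counter decides which occurrence to return:

from collections import defaultdict

import torch

timesteps = torch.tensor([999, 999, 499, 499, 0])  # every step listed twice
index_counter = defaultdict(int)

def index_for_timestep(timestep):
    indices = (timesteps == timestep).nonzero()
    if len(index_counter) == 0:
        # very first call: take the second match so no sigma is skipped when
        # starting in the middle of the schedule (e.g. image-to-image)
        pos = 1 if len(indices) > 1 else 0
    else:
        pos = index_counter[int(timestep)]
    return indices[pos].item()

print(index_for_timestep(999))  # 1  (first-order pass of the very first step)
index_counter[999] += 1         # `step` advances the counter after each call
print(index_for_timestep(499))  # 2  (first-order pass)
index_counter[499] += 1
print(index_for_timestep(499))  # 3  (second-order pass)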
@property
|
||||
@@ -292,6 +314,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.mid_point_sigma = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
def _second_order_timesteps(self, sigmas, log_sigmas):
|
||||
def sigma_fn(_t):
|
||||
return np.exp(-_t)
|
||||
@@ -373,6 +399,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
# Create a noise sampler if it hasn't been created yet
|
||||
if self.noise_sampler is None:
|
||||
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
|
||||
|
||||
@@ -29,7 +29,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -42,19 +46,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,11 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,11 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -74,6 +90,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf).
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
||||
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
||||
@@ -100,6 +120,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
clip_sample: Optional[bool] = False,
|
||||
clip_sample_range: float = 1.0,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
@@ -114,7 +136,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
|
||||
elif beta_schedule == "exp":
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
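Usage sketch: the configuration Shap-E passes to this scheduler, mirroring the test components added later in this diff:

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(
    beta_schedule="exp",       # new branch above -> alpha_transform_type="exp"
    num_train_timesteps=1024,
    prediction_type="sample",  # now handled in `step` (see further below)
    use_karras_sigmas=True,
    clip_sample=True,          # new config options added in this diff
    clip_sample_range=1.0,
)
scheduler.set_timesteps(64)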
@@ -131,10 +155,16 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if self.state_in_first_order:
|
||||
pos = -1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
pos = 0
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
@@ -207,7 +237,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.use_karras_sigmas:
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
@@ -228,6 +258,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.prev_derivative = None
|
||||
self.dt = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
@@ -292,6 +326,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
@@ -316,12 +354,17 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample / (sigma_input**2 + 1)
|
||||
)
|
||||
elif self.config.prediction_type == "sample":
|
||||
raise NotImplementedError("prediction_type not implemented yet: sample")
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.state_in_first_order:
|
||||
# 2. Convert to an ODE derivative for 1st order
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -24,7 +25,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -37,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -130,10 +146,16 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if self.state_in_first_order:
|
||||
pos = -1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
pos = 0
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
@@ -245,6 +267,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
@@ -295,6 +321,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_interpol = self.sigmas_interpol[step_index]
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -129,10 +145,16 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if self.state_in_first_order:
|
||||
pos = -1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
pos = 0
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
@@ -234,6 +256,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
@@ -283,6 +309,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_interpol = self.sigmas_interpol[step_index + 1]
|
||||
|
||||
@@ -45,7 +45,11 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -58,19 +62,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -38,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -43,7 +43,11 @@ class RePaintSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -56,19 +60,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,11 @@ class UnCLIPSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
@@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ if is_torch_available():
|
||||
)
|
||||
from .torch_utils import maybe_allow_in_graph
|
||||
|
||||
from .testing_utils import export_to_video
|
||||
from .testing_utils import export_to_gif, export_to_video
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -377,6 +377,36 @@ class SemanticStableDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ShapEImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ShapEPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -300,6 +300,21 @@ def preprocess_image(image: PIL.Image, batch_size: int):
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
|
||||
if output_gif_path is None:
|
||||
output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
|
||||
|
||||
image[0].save(
|
||||
output_gif_path,
|
||||
save_all=True,
|
||||
append_images=image[1:],
|
||||
optimize=False,
|
||||
duration=100,
|
||||
loop=0,
|
||||
)
|
||||
return output_gif_path
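Usage sketch for the new helper, assuming Pillow and torch are installed (the frames here are synthetic placeholders):

from PIL import Image

from diffusers.utils import export_to_gif

frames = [Image.new("RGB", (64, 64), (i, i, i)) for i in range(0, 200, 10)]
gif_path = export_to_gif(frames)        # writes to a temporary .gif and returns its path
export_to_gif(frames, "rotation.gif")   # or write to an explicit path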
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
tests/pipelines/shap_e/__init__.py (new file, 0 lines)
tests/pipelines/shap_e/test_shap_e.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline
|
||||
from diffusers.pipelines.shap_e import ShapERenderer
|
||||
from diffusers.utils import load_numpy, slow
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
||||
class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ShapEPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["prompt"]
|
||||
required_optional_params = [
|
||||
"num_images_per_prompt",
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"frame_size",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
test_xformers_attention = False
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
return self.time_input_dim * 4
|
||||
|
||||
@property
|
||||
def renderer_dim(self):
|
||||
return 8
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config)
|
||||
|
||||
@property
|
||||
def dummy_prior(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 16,
|
||||
"embedding_dim": self.time_input_dim,
|
||||
"num_embeddings": 32,
|
||||
"embedding_proj_dim": self.text_embedder_hidden_size,
|
||||
"time_embed_dim": self.time_embed_dim,
|
||||
"num_layers": 1,
|
||||
"clip_embed_dim": self.time_input_dim * 2,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
}
|
||||
|
||||
model = PriorTransformer(**model_kwargs)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_renderer(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"param_shapes": (
|
||||
(self.renderer_dim, 93),
|
||||
(self.renderer_dim, 8),
|
||||
(self.renderer_dim, 8),
|
||||
(self.renderer_dim, 8),
|
||||
),
|
||||
"d_latent": self.time_input_dim,
|
||||
"d_hidden": self.renderer_dim,
|
||||
"n_output": 12,
|
||||
"background": (
|
||||
0.1,
|
||||
0.1,
|
||||
0.1,
|
||||
),
|
||||
}
|
||||
model = ShapERenderer(**model_kwargs)
|
||||
return model
|
||||
|
||||
def get_dummy_components(self):
|
||||
prior = self.dummy_prior
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
renderer = self.dummy_renderer
|
||||
|
||||
scheduler = HeunDiscreteScheduler(
|
||||
beta_schedule="exp",
|
||||
num_train_timesteps=1024,
|
||||
prediction_type="sample",
|
||||
use_karras_sigmas=True,
|
||||
clip_sample=True,
|
||||
clip_sample_range=1.0,
|
||||
)
|
||||
components = {
|
||||
"prior": prior,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"renderer": renderer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 1,
|
||||
"frame_size": 32,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_shap_e(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.images[0]
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (20, 32, 32, 3)
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
0.00039216,
|
||||
]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_batch_consistent(self):
|
||||
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
|
||||
self._test_inference_batch_consistent(batch_sizes=[1, 2])
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
|
||||
self._test_inference_batch_single_identical(
|
||||
batch_size=2,
|
||||
test_max_difference=test_max_difference,
|
||||
relax_max_difference=relax_max_difference,
|
||||
)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_size = 1
|
||||
num_images_per_prompt = 2
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ShapEPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_shap_e(self):
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/shap_e/test_shap_e_np_out.npy"
|
||||
)
|
||||
pipe = ShapEPipeline.from_pretrained("openai/shap-e")
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
images = pipe(
|
||||
"a shark",
|
||||
generator=generator,
|
||||
guidance_scale=15.0,
|
||||
num_inference_steps=64,
|
||||
frame_size=64,
|
||||
output_type="np",
|
||||
).images[0]
|
||||
|
||||
assert images.shape == (20, 64, 64, 3)
|
||||
|
||||
assert_mean_pixel_difference(images, expected_image)
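An end-to-end sketch tying the new pipeline and GIF helper together; output_type="pil" is an assumption (the integration test above uses "np"), the remaining arguments mirror that test:

import torch

from diffusers import ShapEPipeline
from diffusers.utils import export_to_gif

pipe = ShapEPipeline.from_pretrained("openai/shap-e").to("cuda")
frames = pipe(
    "a shark",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=15.0,
    num_inference_steps=64,
    frame_size=64,
    output_type="pil",
).images[0]                      # 20 rendered views of the object
export_to_gif(frames, "shark_3d.gif")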
tests/pipelines/shap_e/test_shap_e_img2img.py (new file, 281 lines)
@@ -0,0 +1,281 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
|
||||
|
||||
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline
|
||||
from diffusers.pipelines.shap_e import ShapERenderer
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
||||
class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ShapEImg2ImgPipeline
|
||||
params = ["image"]
|
||||
batch_params = ["image"]
|
||||
required_optional_params = [
|
||||
"num_images_per_prompt",
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"frame_size",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
test_xformers_attention = False
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
return self.time_input_dim * 4
|
||||
|
||||
@property
|
||||
def renderer_dim(self):
|
||||
return 8
|
||||
|
||||
@property
|
||||
def dummy_image_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPVisionConfig(
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
image_size=64,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
num_attention_heads=4,
|
||||
num_channels=3,
|
||||
num_hidden_layers=5,
|
||||
patch_size=1,
|
||||
)
|
||||
|
||||
model = CLIPVisionModel(config)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_image_processor(self):
|
||||
image_processor = CLIPImageProcessor(
|
||||
crop_size=224,
|
||||
do_center_crop=True,
|
||||
do_normalize=True,
|
||||
do_resize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
resample=3,
|
||||
size=224,
|
||||
)
|
||||
|
||||
return image_processor
|
||||
|
||||
@property
|
||||
def dummy_prior(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 16,
|
||||
"embedding_dim": self.time_input_dim,
|
||||
"num_embeddings": 32,
|
||||
"embedding_proj_dim": self.text_embedder_hidden_size,
|
||||
"time_embed_dim": self.time_embed_dim,
|
||||
"num_layers": 1,
|
||||
"clip_embed_dim": self.time_input_dim * 2,
|
||||
"additional_embeddings": 0,
|
||||
"time_embed_act_fn": "gelu",
|
||||
"norm_in_type": "layer",
|
||||
"embedding_proj_norm_type": "layer",
|
||||
"encoder_hid_proj_type": None,
|
||||
"added_emb_type": None,
|
||||
}
|
||||
|
||||
model = PriorTransformer(**model_kwargs)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_renderer(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"param_shapes": (
|
||||
(self.renderer_dim, 93),
|
||||
(self.renderer_dim, 8),
|
||||
(self.renderer_dim, 8),
|
||||
(self.renderer_dim, 8),
|
||||
),
|
||||
"d_latent": self.time_input_dim,
|
||||
"d_hidden": self.renderer_dim,
|
||||
"n_output": 12,
|
||||
"background": (
|
||||
0.1,
|
||||
0.1,
|
||||
0.1,
|
||||
),
|
||||
}
|
||||
model = ShapERenderer(**model_kwargs)
|
||||
return model
|
||||
|
||||

    def get_dummy_components(self):
        prior = self.dummy_prior
        image_encoder = self.dummy_image_encoder
        image_processor = self.dummy_image_processor
        renderer = self.dummy_renderer

        # scheduler configured the way the Shap-E pipelines expect it:
        # exp beta schedule, Karras sigmas, and "sample" prediction with sample clipping
        scheduler = HeunDiscreteScheduler(
            beta_schedule="exp",
            num_train_timesteps=1024,
            prediction_type="sample",
            use_karras_sigmas=True,
            clip_sample=True,
            clip_sample_range=1.0,
        )
        components = {
            "prior": prior,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
            "renderer": renderer,
            "scheduler": scheduler,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        input_image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)

        if str(device).startswith("mps"):
            # mps does not support device-bound generators, so seed globally instead
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "image": input_image,
            "generator": generator,
            "num_inference_steps": 1,
            "frame_size": 32,
            "output_type": "np",
        }
        return inputs

    def test_shap_e(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device))
        image = output.images[0]
        image_slice = image[0, -3:, -3:, -1]

        # each output image is a stack of rendered views: (num_frames, frame_size, frame_size, 3)
        assert image.shape == (20, 32, 32, 3)

        expected_slice = np.array(
            [
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
                0.00039216,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_inference_batch_consistent(self):
        # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
        self._test_inference_batch_consistent(batch_sizes=[1, 2])

    def test_inference_batch_single_identical(self):
        test_max_difference = torch_device == "cpu"
        relax_max_difference = True
        self._test_inference_batch_single_identical(
            batch_size=2,
            test_max_difference=test_max_difference,
            relax_max_difference=relax_max_difference,
        )

    def test_num_images_per_prompt(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        batch_size = 1
        num_images_per_prompt = 2

        inputs = self.get_dummy_inputs(torch_device)

        for key in inputs.keys():
            if key in self.batch_params:
                inputs[key] = batch_size * [inputs[key]]

        images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]

        assert images.shape[0] == batch_size * num_images_per_prompt


@slow
@require_torch_gpu
class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_shap_e_img2img(self):
        input_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/shap_e/test_shap_e_img2img_out.npy"
        )
        pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img")
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)

        images = pipe(
            input_image,
            generator=generator,
            guidance_scale=3.0,
            num_inference_steps=64,
            frame_size=64,
            output_type="np",
        ).images[0]

        assert images.shape == (20, 64, 64, 3)

        assert_mean_pixel_difference(images, expected_image)
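
A minimal usage sketch of the new img2img pipeline outside the test harness (not part of this diff). It assumes the GIF helper added alongside the pipelines is exposed as `diffusers.utils.export_to_gif`; adjust the import if the utility lives elsewhere.

import torch

from diffusers import ShapEImg2ImgPipeline
from diffusers.utils import export_to_gif, load_image

# hypothetical end-user flow: render 3D views of the corgi image and save them as a GIF
pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png"
)

# with output_type="pil" the pipeline returns one list of PIL frames per input image
frames = pipe(
    image,
    guidance_scale=3.0,
    num_inference_steps=64,
    frame_size=64,
    output_type="pil",
).images

export_to_gif(frames[0], "corgi_3d.gif")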
@@ -30,11 +30,15 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
         self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
     def test_schedules(self):
-        for schedule in ["linear", "scaled_linear"]:
+        for schedule in ["linear", "scaled_linear", "exp"]:
             self.check_over_configs(beta_schedule=schedule)
 
+    def test_clip_sample(self):
+        for clip_sample_range in [1.0, 2.0, 3.0]:
+            self.check_over_configs(clip_sample_range=clip_sample_range, clip_sample=True)
+
     def test_prediction_type(self):
-        for prediction_type in ["epsilon", "v_prediction"]:
+        for prediction_type in ["epsilon", "v_prediction", "sample"]:
             self.check_over_configs(prediction_type=prediction_type)
 
     def test_full_loop_no_noise(self):
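
As a small sketch (not part of the diff), the scheduler options validated above can be set directly when constructing the scheduler; the values below mirror the dummy-component configuration and are illustrative only.

from diffusers import HeunDiscreteScheduler

# exp beta schedule, "sample" prediction, and sample clipping are the options
# exercised by the updated tests above
scheduler = HeunDiscreteScheduler(
    num_train_timesteps=1024,
    beta_schedule="exp",
    prediction_type="sample",
    use_karras_sigmas=True,
    clip_sample=True,
    clip_sample_range=1.0,
)

scheduler.set_timesteps(64, device="cpu")
print(scheduler.timesteps.shape)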