1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Shap-E: add support for mesh output (#4062)

* add output_type=mesh

* update img2img

* make style

* add doc

* make style

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add docstring for output_type

* add a section in doc about hub mesh visualization/ rotation

* update conversion script so default background is white

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* renderer -> shap_e_renderer

* img2img renderer -> shap_e_renderer

* fix tests

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
YiYi Xu
2023-07-20 06:05:13 -10:00
committed by GitHub
parent 07f1fbb18e
commit 47b3346422
9 changed files with 1040 additions and 59 deletions

View File

@@ -128,6 +128,63 @@ gif_path = export_to_gif(images[0], "burger_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif)
### Generate mesh
For both [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`], you can generate mesh output by passing `output_type` as `mesh` to the pipeline, and then use the [`ShapEPipeline.export_to_ply`] utility function to save the output as a `ply` file. We also provide a [`ShapEPipeline.export_to_obj`] function that you can use to save mesh outputs as `obj` files.
```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_ply
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo = "openai/shap-e"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to(device)
guidance_scale = 15.0
prompt = "A birthday cupcake"
images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images
ply_path = export_to_ply(images[0], "3d_cake.ply")
print(f"saved to folder: {ply_path}")
```
Huggingface Datasets supports mesh visualization for mesh files in `glb` format. Below we will show you how to convert your mesh file into `glb` format so that you can use the Dataset viewer to render 3D objects.
We need to install `trimesh` library.
```
pip install trimesh
```
To convert the mesh file into `glb` format,
```python
import trimesh
mesh = trimesh.load("3d_cake.ply")
mesh.export("3d_cake.glb", file_type="glb")
```
By default, the mesh output of Shap-E is from the bottom viewpoint; you can change the default viewpoint by applying a rotation transformation
```python
import trimesh
import numpy as np
mesh = trimesh.load("3d_cake.ply")
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
mesh = mesh.apply_transform(rot)
mesh.export("3d_cake.glb", file_type="glb")
```
Now you can upload your mesh file to your dataset and visualize it! Here is the link to the 3D cake we just generated
https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/shap_e/3d_cake.glb
## ShapEPipeline
[[autodoc]] ShapEPipeline
- all

View File

@@ -22,7 +22,7 @@ $ python scripts/convert_shap_e_to_diffusers.py \
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
--prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
--transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e-img2img/shap_e_renderer\
--debug renderer
```
"""
@@ -373,6 +373,487 @@ def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
# renderer
## create the lookup table for marching cubes method used in MeshDecoder
MC_TABLE = [
[],
[[0, 1, 0, 2, 0, 4]],
[[1, 0, 1, 5, 1, 3]],
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]],
[[2, 0, 2, 3, 2, 6]],
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]],
[[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]],
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]],
[[3, 1, 3, 7, 3, 2]],
[[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]],
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]],
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]],
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]],
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]],
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]],
[[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]],
[[4, 0, 4, 6, 4, 5]],
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]],
[[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]],
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]],
[[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]],
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]],
[[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]],
[[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]],
[[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]],
[[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]],
[[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]],
[[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]],
[[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]],
[[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]],
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]],
[[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]],
[[5, 1, 5, 4, 5, 7]],
[[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]],
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]],
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]],
[[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]],
[[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]],
[[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]],
[[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]],
[[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]],
[[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]],
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]],
[[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]],
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]],
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]],
[[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]],
[[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]],
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]],
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]],
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]],
[[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]],
[[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]],
[[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]],
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]],
[[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]],
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]],
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]],
[[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]],
[[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]],
[[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]],
[
[3, 1, 7, 3, 6, 2],
[6, 2, 0, 1, 3, 1],
[6, 4, 0, 1, 6, 2],
[6, 4, 5, 1, 0, 1],
[6, 4, 7, 5, 5, 1],
],
[
[4, 0, 6, 4, 7, 5],
[7, 5, 1, 0, 4, 0],
[7, 3, 1, 0, 7, 5],
[7, 3, 2, 0, 1, 0],
[7, 3, 6, 2, 2, 0],
],
[[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]],
[[6, 2, 6, 7, 6, 4]],
[[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]],
[[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]],
[[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]],
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]],
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]],
[[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]],
[[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]],
[[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]],
[[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]],
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]],
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]],
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]],
[[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]],
[[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]],
[[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]],
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]],
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]],
[[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]],
[[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]],
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]],
[[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]],
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]],
[[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]],
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]],
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]],
[[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]],
[
[6, 2, 7, 6, 5, 4],
[5, 4, 0, 2, 6, 2],
[5, 1, 0, 2, 5, 4],
[5, 1, 3, 2, 0, 2],
[5, 1, 7, 3, 3, 2],
],
[[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]],
[[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]],
[
[1, 0, 5, 1, 7, 3],
[7, 3, 2, 0, 1, 0],
[7, 6, 2, 0, 7, 3],
[7, 6, 4, 0, 2, 0],
[7, 6, 5, 4, 4, 0],
],
[[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]],
[[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]],
[[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]],
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]],
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]],
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]],
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]],
[[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]],
[
[5, 4, 7, 5, 3, 1],
[3, 1, 0, 4, 5, 4],
[3, 2, 0, 4, 3, 1],
[3, 2, 6, 4, 0, 4],
[3, 2, 7, 6, 6, 4],
],
[[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]],
[[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]],
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]],
[
[0, 4, 2, 3, 0, 2],
[0, 4, 3, 7, 2, 3],
[0, 4, 4, 5, 3, 7],
[4, 5, 5, 7, 3, 7],
[6, 7, 4, 6, 2, 6],
],
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]],
[
[0, 1, 4, 6, 0, 4],
[0, 1, 6, 7, 4, 6],
[0, 1, 1, 3, 6, 7],
[1, 3, 3, 7, 6, 7],
[5, 7, 1, 5, 4, 5],
],
[
[6, 7, 4, 6, 0, 2],
[0, 2, 3, 7, 6, 7],
[0, 1, 3, 7, 0, 2],
[0, 1, 5, 7, 3, 7],
[0, 1, 4, 5, 5, 7],
],
[[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]],
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]],
[[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]],
[[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]],
[[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]],
[[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]],
[[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]],
[
[2, 0, 3, 2, 7, 6],
[7, 6, 4, 0, 2, 0],
[7, 5, 4, 0, 7, 6],
[7, 5, 1, 0, 4, 0],
[7, 5, 3, 1, 1, 0],
],
[[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]],
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]],
[
[0, 2, 1, 5, 0, 1],
[0, 2, 5, 7, 1, 5],
[0, 2, 2, 6, 5, 7],
[2, 6, 6, 7, 5, 7],
[3, 7, 2, 3, 1, 3],
],
[
[3, 7, 2, 3, 0, 1],
[0, 1, 5, 7, 3, 7],
[0, 4, 5, 7, 0, 1],
[0, 4, 6, 7, 5, 7],
[0, 4, 2, 6, 6, 7],
],
[[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]],
[
[5, 7, 1, 5, 0, 4],
[0, 4, 6, 7, 5, 7],
[0, 2, 6, 7, 0, 4],
[0, 2, 3, 7, 6, 7],
[0, 2, 1, 3, 3, 7],
],
[[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]],
[[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]],
[[7, 5, 7, 3, 7, 6]],
[[7, 3, 7, 5, 7, 6]],
[[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]],
[[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]],
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]],
[[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]],
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]],
[[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]],
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]],
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]],
[[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]],
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]],
[[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]],
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]],
[[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]],
[[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]],
[[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]],
[[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]],
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]],
[[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]],
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]],
[[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]],
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]],
[[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]],
[
[1, 3, 5, 4, 1, 5],
[1, 3, 4, 6, 5, 4],
[1, 3, 3, 2, 4, 6],
[3, 2, 2, 6, 4, 6],
[7, 6, 3, 7, 5, 7],
],
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]],
[[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]],
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]],
[
[7, 5, 6, 7, 2, 3],
[2, 3, 1, 5, 7, 5],
[2, 0, 1, 5, 2, 3],
[2, 0, 4, 5, 1, 5],
[2, 0, 6, 4, 4, 5],
],
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]],
[
[4, 6, 5, 4, 1, 0],
[1, 0, 2, 6, 4, 6],
[1, 3, 2, 6, 1, 0],
[1, 3, 7, 6, 2, 6],
[1, 3, 5, 7, 7, 6],
],
[
[1, 5, 0, 2, 1, 0],
[1, 5, 2, 6, 0, 2],
[1, 5, 5, 7, 2, 6],
[5, 7, 7, 6, 2, 6],
[4, 6, 5, 4, 0, 4],
],
[[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]],
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]],
[[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]],
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]],
[[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]],
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]],
[[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]],
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]],
[
[2, 3, 6, 2, 4, 0],
[4, 0, 1, 3, 2, 3],
[4, 5, 1, 3, 4, 0],
[4, 5, 7, 3, 1, 3],
[4, 5, 6, 7, 7, 3],
],
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]],
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]],
[[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]],
[[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]],
[[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]],
[
[5, 1, 4, 5, 6, 7],
[6, 7, 3, 1, 5, 1],
[6, 2, 3, 1, 6, 7],
[6, 2, 0, 1, 3, 1],
[6, 2, 4, 0, 0, 1],
],
[[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]],
[[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]],
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]],
[[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]],
[[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]],
[[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]],
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]],
[
[7, 6, 3, 7, 1, 5],
[1, 5, 4, 6, 7, 6],
[1, 0, 4, 6, 1, 5],
[1, 0, 2, 6, 4, 6],
[1, 0, 3, 2, 2, 6],
],
[
[1, 0, 3, 7, 1, 3],
[1, 0, 7, 6, 3, 7],
[1, 0, 0, 4, 7, 6],
[0, 4, 4, 6, 7, 6],
[2, 6, 0, 2, 3, 2],
],
[[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]],
[[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]],
[
[0, 1, 2, 0, 6, 4],
[6, 4, 5, 1, 0, 1],
[6, 7, 5, 1, 6, 4],
[6, 7, 3, 1, 5, 1],
[6, 7, 2, 3, 3, 1],
],
[[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]],
[[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]],
[
[2, 6, 0, 2, 1, 3],
[1, 3, 7, 6, 2, 6],
[1, 5, 7, 6, 1, 3],
[1, 5, 4, 6, 7, 6],
[1, 5, 0, 4, 4, 6],
],
[[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]],
[[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]],
[[6, 7, 6, 2, 6, 4]],
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]],
[[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]],
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]],
[[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]],
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]],
[[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]],
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]],
[
[7, 3, 5, 7, 4, 6],
[4, 6, 2, 3, 7, 3],
[4, 0, 2, 3, 4, 6],
[4, 0, 1, 3, 2, 3],
[4, 0, 5, 1, 1, 3],
],
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]],
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]],
[[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]],
[
[0, 2, 4, 0, 5, 1],
[5, 1, 3, 2, 0, 2],
[5, 7, 3, 2, 5, 1],
[5, 7, 6, 2, 3, 2],
[5, 7, 4, 6, 6, 2],
],
[[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]],
[[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]],
[[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]],
[[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]],
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]],
[[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]],
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]],
[
[1, 5, 3, 1, 2, 0],
[2, 0, 4, 5, 1, 5],
[2, 6, 4, 5, 2, 0],
[2, 6, 7, 5, 4, 5],
[2, 6, 3, 7, 7, 5],
],
[[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]],
[[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]],
[
[2, 3, 0, 4, 2, 0],
[2, 3, 4, 5, 0, 4],
[2, 3, 3, 7, 4, 5],
[3, 7, 7, 5, 4, 5],
[1, 5, 3, 1, 0, 1],
],
[[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]],
[[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]],
[
[3, 2, 1, 3, 5, 7],
[5, 7, 6, 2, 3, 2],
[5, 4, 6, 2, 5, 7],
[5, 4, 0, 2, 6, 2],
[5, 4, 1, 0, 0, 2],
],
[
[4, 5, 0, 4, 2, 6],
[2, 6, 7, 5, 4, 5],
[2, 3, 7, 5, 2, 6],
[2, 3, 1, 5, 7, 5],
[2, 3, 0, 1, 1, 5],
],
[[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]],
[[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]],
[[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]],
[[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]],
[[5, 4, 5, 1, 5, 7]],
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]],
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]],
[[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]],
[
[6, 4, 2, 6, 3, 7],
[3, 7, 5, 4, 6, 4],
[3, 1, 5, 4, 3, 7],
[3, 1, 0, 4, 5, 4],
[3, 1, 2, 0, 0, 4],
],
[[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]],
[
[0, 4, 1, 0, 3, 2],
[3, 2, 6, 4, 0, 4],
[3, 7, 6, 4, 3, 2],
[3, 7, 5, 4, 6, 4],
[3, 7, 1, 5, 5, 4],
],
[
[1, 3, 0, 1, 4, 5],
[4, 5, 7, 3, 1, 3],
[4, 6, 7, 3, 4, 5],
[4, 6, 2, 3, 7, 3],
[4, 6, 0, 2, 2, 3],
],
[[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]],
[[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]],
[
[3, 1, 2, 6, 3, 2],
[3, 1, 6, 4, 2, 6],
[3, 1, 1, 5, 6, 4],
[1, 5, 5, 4, 6, 4],
[0, 4, 1, 0, 2, 0],
],
[[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]],
[[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]],
[[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]],
[[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]],
[[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]],
[[4, 6, 4, 0, 4, 5]],
[[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]],
[[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]],
[[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]],
[[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]],
[[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]],
[[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]],
[[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]],
[[3, 7, 3, 1, 3, 2]],
[[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]],
[[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]],
[[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]],
[[2, 3, 2, 0, 2, 6]],
[[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]],
[[1, 5, 1, 0, 1, 3]],
[[0, 2, 0, 1, 0, 4]],
[],
]
def create_mc_lookup_table():
cases = torch.zeros(256, 5, 3, dtype=torch.long)
masks = torch.zeros(256, 5, dtype=torch.bool)
edge_to_index = {
(0, 1): 0,
(2, 3): 1,
(4, 5): 2,
(6, 7): 3,
(0, 2): 4,
(1, 3): 5,
(4, 6): 6,
(5, 7): 7,
(0, 4): 8,
(1, 5): 9,
(2, 6): 10,
(3, 7): 11,
}
for i, case in enumerate(MC_TABLE):
for j, tri in enumerate(case):
for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
masks[i, j] = True
return cases, masks
RENDERER_CONFIG = {}
@@ -400,7 +881,12 @@ def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint
}
)
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
diffusers_checkpoint.update({"void.background": model.state_dict()["void.background"]})
cases, masks = create_mc_lookup_table()
diffusers_checkpoint.update({"mesh_decoder.cases": cases})
diffusers_checkpoint.update({"mesh_decoder.masks": masks})
return diffusers_checkpoint

View File

@@ -95,7 +95,7 @@ class ShapEPipeline(DiffusionPipeline):
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`HeunDiscreteScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
renderer ([`ShapERenderer`]):
shap_e_renderer ([`ShapERenderer`]):
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
with the NeRF rendering method
"""
@@ -106,7 +106,7 @@ class ShapEPipeline(DiffusionPipeline):
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
scheduler: HeunDiscreteScheduler,
renderer: ShapERenderer,
shap_e_renderer: ShapERenderer,
):
super().__init__()
@@ -115,7 +115,7 @@ class ShapEPipeline(DiffusionPipeline):
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
renderer=renderer,
shap_e_renderer=shap_e_renderer,
)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
@@ -149,7 +149,7 @@ class ShapEPipeline(DiffusionPipeline):
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
for cpu_offloaded_model in [self.text_encoder, self.prior, self.shap_e_renderer]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
if self.safety_checker is not None:
@@ -218,7 +218,7 @@ class ShapEPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 4.0,
frame_size: int = 64,
output_type: Optional[str] = "pil", # pil, np, latent
output_type: Optional[str] = "pil", # pil, np, latent, mesh
return_dict: bool = True,
):
"""
@@ -248,8 +248,8 @@ class ShapEPipeline(DiffusionPipeline):
frame_size (`int`, *optional*, default to 64):
the width and height of each image frame of the generated 3d output
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
@@ -319,30 +319,39 @@ class ShapEPipeline(DiffusionPipeline):
sample=latents,
).prev_sample
if output_type not in ["np", "pil", "latent", "mesh"]:
raise ValueError(
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
)
if output_type == "latent":
return ShapEPipelineOutput(images=latents)
images = []
for i, latent in enumerate(latents):
image = self.renderer.decode(
latent[None, :],
device,
size=frame_size,
ray_batch_size=4096,
n_coarse_samples=64,
n_fine_samples=128,
)
images.append(image)
if output_type == "mesh":
for i, latent in enumerate(latents):
mesh = self.shap_e_renderer.decode_to_mesh(
latent[None, :],
device,
)
images.append(mesh)
images = torch.stack(images)
else:
# np, pil
for i, latent in enumerate(latents):
image = self.shap_e_renderer.decode_to_image(
latent[None, :],
device,
size=frame_size,
)
images.append(image)
if output_type not in ["np", "pil"]:
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
images = torch.stack(images)
images = images.cpu().numpy()
images = images.cpu().numpy()
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

View File

@@ -94,7 +94,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`HeunDiscreteScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
renderer ([`ShapERenderer`]):
shap_e_renderer ([`ShapERenderer`]):
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
with the NeRF rendering method
"""
@@ -105,7 +105,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
scheduler: HeunDiscreteScheduler,
renderer: ShapERenderer,
shap_e_renderer: ShapERenderer,
):
super().__init__()
@@ -114,7 +114,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
image_encoder=image_encoder,
image_processor=image_processor,
scheduler=scheduler,
renderer=renderer,
shap_e_renderer=shap_e_renderer,
)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
@@ -170,7 +170,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 4.0,
frame_size: int = 64,
output_type: Optional[str] = "pil", # pil, np, latent
output_type: Optional[str] = "pil", # pil, np, latent, mesh
return_dict: bool = True,
):
"""
@@ -200,8 +200,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
frame_size (`int`, *optional*, default to 64):
the width and height of each image frame of the generated 3d output
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
@@ -275,32 +274,39 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
sample=latents,
).prev_sample
if output_type not in ["np", "pil", "latent", "mesh"]:
raise ValueError(
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
)
if output_type == "latent":
return ShapEPipelineOutput(images=latents)
images = []
for i, latent in enumerate(latents):
print()
image = self.renderer.decode(
latent[None, :],
device,
size=frame_size,
ray_batch_size=4096,
n_coarse_samples=64,
n_fine_samples=128,
)
if output_type == "mesh":
for i, latent in enumerate(latents):
mesh = self.shap_e_renderer.decode_to_mesh(
latent[None, :],
device,
)
images.append(mesh)
images.append(image)
else:
# np, pil
for i, latent in enumerate(latents):
image = self.shap_e_renderer.decode_to_image(
latent[None, :],
device,
size=frame_size,
)
images.append(image)
images = torch.stack(images)
images = torch.stack(images)
if output_type not in ["np", "pil"]:
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
images = images.cpu().numpy()
images = images.cpu().numpy()
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
if output_type == "pil":
images = [self.numpy_to_pil(image) for image in images]
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
import numpy as np
import torch
@@ -116,6 +116,101 @@ def integrate_samples(volume_range, ts, density, channels):
return channels, weights, transmittance
def volume_query_points(volume, grid_size):
indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
zs = indices % grid_size
ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
combined = torch.stack([xs, ys, zs], dim=1)
return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min
def _convert_srgb_to_linear(u: torch.Tensor):
return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)
def _create_flat_edge_indices(
flat_cube_indices: torch.Tensor,
grid_size: Tuple[int, int, int],
):
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
y_offset = num_xs
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
z_offset = num_xs + num_ys
return torch.stack(
[
# Edges spanning x-axis.
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
# Edges spanning y-axis.
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
# Edges spanning z-axis.
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
],
dim=-1,
)
class VoidNeRFModel(nn.Module):
"""
Implements the default empty space model where all queries are rendered as background.
@@ -368,6 +463,141 @@ class ImportanceRaySampler(nn.Module):
return ts
@dataclass
class MeshDecoderOutput(BaseOutput):
"""
A 3D triangle mesh with optional data at the vertices and faces.
Args:
verts (`torch.Tensor` of shape `(N, 3)`):
array of vertext coordinates
faces (`torch.Tensor` of shape `(N, 3)`):
array of triangles, pointing to indices in verts.
vertext_channels (Dict):
vertext coordinates for each color channel
"""
verts: torch.Tensor
faces: torch.Tensor
vertex_channels: Dict[str, torch.Tensor]
class MeshDecoder(nn.Module):
"""
Construct meshes from Signed distance functions (SDFs) using marching cubes method
"""
def __init__(self):
super().__init__()
cases = torch.zeros(256, 5, 3, dtype=torch.long)
masks = torch.zeros(256, 5, dtype=torch.bool)
self.register_buffer("cases", cases)
self.register_buffer("masks", masks)
def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor):
"""
For a signed distance field, produce a mesh using marching cubes.
:param field: a 3D tensor of field values, where negative values correspond
to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively.
:param min_point: a tensor of shape [3] containing the point corresponding
to (0, 0, 0) in the field.
:param size: a tensor of shape [3] containing the per-axis distance from the
(0, 0, 0) field corner and the (-1, -1, -1) field corner.
"""
assert len(field.shape) == 3, "input must be a 3D scalar field"
dev = field.device
cases = self.cases.to(dev)
masks = self.masks.to(dev)
min_point = min_point.to(dev)
size = size.to(dev)
grid_size = field.shape
grid_size_tensor = torch.tensor(grid_size).to(size)
# Create bitmasks between 0 and 255 (inclusive) indicating the state
# of the eight corners of each cube.
bitmasks = (field > 0).to(torch.uint8)
bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)
# Compute corner coordinates across the entire grid.
corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
:, None, None
]
corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
:, None
]
corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)
# Compute all vertices across all edges in the grid, even though we will
# throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
# These are all midpoints, and don't account for interpolation (which is
# done later based on the used edge midpoints).
edge_midpoints = torch.cat(
[
((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
],
dim=0,
)
# Create a flat array of [X, Y, Z] indices for each cube.
cube_indices = torch.zeros(
grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
)
cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
flat_cube_indices = cube_indices.reshape(-1, 3)
# Create a flat array mapping each cube to 12 global edge indices.
edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)
# Apply the LUT to figure out the triangles.
flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask
local_tris = cases[flat_bitmasks]
local_masks = masks[flat_bitmasks]
# Compute the global edge indices for the triangles.
global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
local_tris.shape
)
# Select the used triangles for each cube.
selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]
# Now we have a bunch of indices into the full list of possible vertices,
# but we want to reduce this list to only the used vertices.
used_vertex_indices = torch.unique(selected_tris.view(-1))
used_edge_midpoints = edge_midpoints[used_vertex_indices]
old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
old_index_to_new_index[used_vertex_indices] = torch.arange(
len(used_vertex_indices), device=dev, dtype=torch.long
)
# Rewrite the triangles to use the new indices
faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)
# Compute the actual interpolated coordinates corresponding to edge midpoints.
v1 = torch.floor(used_edge_midpoints).to(torch.long)
v2 = torch.ceil(used_edge_midpoints).to(torch.long)
s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
# The signs of s1 and s2 should be different. We want to find
# t such that t*s2 + (1-t)*s1 = 0.
t = (s1 / (s1 - s2))[:, None]
verts = t * p2 + (1 - t) * p1
return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)
@dataclass
class MLPNeRFModelOutput(BaseOutput):
density: torch.Tensor
@@ -429,7 +659,7 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
return mapped_output
def forward(self, *, position, direction, ts, nerf_level="coarse"):
def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
h = encode_position(position)
h_preact = h
@@ -455,10 +685,17 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
if nerf_level == "coarse":
h_density = activation["density_coarse"]
h_channels = activation["nerf_coarse"]
else:
h_density = activation["density_fine"]
h_channels = activation["nerf_fine"]
if rendering_mode == "nerf":
if nerf_level == "coarse":
h_channels = activation["nerf_coarse"]
else:
h_channels = activation["nerf_fine"]
elif rendering_mode == "stf":
h_channels = activation["stf"]
density = self.density_activation(h_density)
signed_distance = self.sdf_activation(activation["sdf"])
@@ -583,6 +820,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
self.mesh_decoder = MeshDecoder()
@torch.no_grad()
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
@@ -664,7 +902,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
return channels, weighted_sampler, model_out
@torch.no_grad()
def decode(
def decode_to_image(
self,
latents,
device,
@@ -707,3 +945,106 @@ class ShapERenderer(ModelMixin, ConfigMixin):
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
return images
@torch.no_grad()
def decode_to_mesh(
self,
latents,
device,
grid_size: int = 128,
query_batch_size: int = 4096,
texture_channels: Tuple = ("R", "G", "B"),
):
# 1. project the the paramters from the generated latents
projected_params = self.params_proj(latents)
# 2. update the mlp layers of the renderer
for name, param in self.mlp.state_dict().items():
if f"nerstf.{name}" in projected_params.keys():
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
# 3. decoding with STF rendering
# 3.1 query the SDF values at vertices along a regular 128**3 grid
query_points = volume_query_points(self.volume, grid_size)
query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype)
fields = []
for idx in range(0, query_positions.shape[1], query_batch_size):
query_batch = query_positions[:, idx : idx + query_batch_size]
model_out = self.mlp(
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
)
fields.append(model_out.signed_distance)
# predicted SDF values
fields = torch.cat(fields, dim=1)
fields = fields.float()
assert (
len(fields.shape) == 3 and fields.shape[-1] == 1
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
fields = fields.reshape(1, *([grid_size] * 3))
# create grid 128 x 128 x 128
# - force a negative border around the SDFs to close off all the models.
full_grid = torch.zeros(
1,
grid_size + 2,
grid_size + 2,
grid_size + 2,
device=fields.device,
dtype=fields.dtype,
)
full_grid.fill_(-1.0)
full_grid[:, 1:-1, 1:-1, 1:-1] = fields
fields = full_grid
# apply a differentiable implementation of Marching Cubes to construct meshs
raw_meshes = []
mesh_mask = []
for field in fields:
raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
mesh_mask.append(True)
raw_meshes.append(raw_mesh)
mesh_mask = torch.tensor(mesh_mask, device=fields.device)
max_vertices = max(len(m.verts) for m in raw_meshes)
# 3.2. query the texture color head at each vertex of the resulting mesh.
texture_query_positions = torch.stack(
[m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
dim=0,
)
texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)
textures = []
for idx in range(0, texture_query_positions.shape[1], query_batch_size):
query_batch = texture_query_positions[:, idx : idx + query_batch_size]
texture_model_out = self.mlp(
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
)
textures.append(texture_model_out.channels)
# predict texture color
textures = torch.cat(textures, dim=1)
textures = _convert_srgb_to_linear(textures)
textures = textures.float()
# 3.3 augument the mesh with texture data
assert len(textures.shape) == 3 and textures.shape[-1] == len(
texture_channels
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))
return raw_meshes[0]

View File

@@ -103,7 +103,7 @@ if is_torch_available():
)
from .torch_utils import maybe_allow_in_graph
from .testing_utils import export_to_gif, export_to_video
from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
logger = get_logger(__name__)

View File

@@ -1,12 +1,15 @@
import inspect
import io
import logging
import multiprocessing
import os
import random
import re
import struct
import tempfile
import unittest
import urllib.parse
from contextlib import contextmanager
from distutils.util import strtobool
from io import BytesIO, StringIO
from pathlib import Path
@@ -315,6 +318,85 @@ def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) ->
return output_gif_path
@contextmanager
def buffered_writer(raw_f):
f = io.BufferedWriter(raw_f)
yield f
f.flush()
def export_to_ply(mesh, output_ply_path: str = None):
"""
Write a PLY file for a mesh.
"""
if output_ply_path is None:
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
coords = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
with buffered_writer(open(output_ply_path, "wb")) as f:
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B")
for item in vertices:
f.write(format.pack(*item))
else:
format = struct.Struct("<3f")
for vertex in coords.tolist():
f.write(format.pack(*vertex))
if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
return output_ply_path
def export_to_obj(mesh, output_obj_path: str = None):
if output_obj_path is None:
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
verts = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
vertices = [
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
]
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
combined_data = ["v " + vertex for vertex in vertices] + faces
with open(output_obj_path, "w") as f:
f.writelines("\n".join(combined_data))
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
if is_opencv_available():
import cv2

View File

@@ -131,7 +131,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
prior = self.dummy_prior
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
renderer = self.dummy_renderer
shap_e_renderer = self.dummy_renderer
scheduler = HeunDiscreteScheduler(
beta_schedule="exp",
@@ -145,7 +145,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"prior": prior,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"renderer": renderer,
"shap_e_renderer": shap_e_renderer,
"scheduler": scheduler,
}

View File

@@ -143,7 +143,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
prior = self.dummy_prior
image_encoder = self.dummy_image_encoder
image_processor = self.dummy_image_processor
renderer = self.dummy_renderer
shap_e_renderer = self.dummy_renderer
scheduler = HeunDiscreteScheduler(
beta_schedule="exp",
@@ -157,7 +157,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"prior": prior,
"image_encoder": image_encoder,
"image_processor": image_processor,
"renderer": renderer,
"shap_e_renderer": shap_e_renderer,
"scheduler": scheduler,
}