mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Shap-E: add support for mesh output (#4062)
* add output_type=mesh * update img2img * make style * add doc * make style * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * add docstring for output_type * add a section in doc about hub mesh visualization/ rotation * update conversion script so default background is white * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * renderer -> shap_e_renderer * img2img renderer -> shap_e_renderer * fix tests --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -128,6 +128,63 @@ gif_path = export_to_gif(images[0], "burger_3d.gif")
|
||||
```
|
||||

|
||||
|
||||
### Generate mesh
|
||||
|
||||
For both [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`], you can generate mesh output by passing `output_type="mesh"` to the pipeline, and then using the [`~utils.export_to_ply`] utility function to save the output as a `ply` file. We also provide a [`~utils.export_to_obj`] function that you can use to save mesh outputs as `obj` files.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_ply
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
repo = "openai/shap-e"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe = pipe.to(device)
|
||||
|
||||
guidance_scale = 15.0
|
||||
prompt = "A birthday cupcake"
|
||||
|
||||
images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images
|
||||
|
||||
ply_path = export_to_ply(images[0], "3d_cake.ply")
|
||||
print(f"saved to folder: {ply_path}")
|
||||
```
|
||||
|
||||
Hugging Face Datasets supports mesh visualization for mesh files in `glb` format. Below we show how to convert your mesh file into `glb` format so that you can use the Dataset viewer to render 3D objects.
|
||||
|
||||
First, install the `trimesh` library:
|
||||
|
||||
```
|
||||
pip install trimesh
|
||||
```
|
||||
|
||||
Then convert the mesh file into `glb` format:
|
||||
|
||||
```python
|
||||
import trimesh
|
||||
|
||||
mesh = trimesh.load("3d_cake.ply")
|
||||
mesh.export("3d_cake.glb", file_type="glb")
|
||||
```
|
||||
|
||||
By default, the mesh output of Shap-E is shown from the bottom viewpoint; you can change the default viewpoint by applying a rotation transformation:
|
||||
|
||||
```python
|
||||
import trimesh
|
||||
import numpy as np
|
||||
|
||||
mesh = trimesh.load("3d_cake.ply")
|
||||
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
|
||||
mesh = mesh.apply_transform(rot)
|
||||
mesh.export("3d_cake.glb", file_type="glb")
|
||||
```
|
||||
|
||||
Now you can upload your mesh file to your dataset and visualize it! Here is the link to the 3D cake we just generated:
|
||||
https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/shap_e/3d_cake.glb
|
||||
|
||||
## ShapEPipeline
|
||||
[[autodoc]] ShapEPipeline
|
||||
- all
|
||||
|
||||
@@ -22,7 +22,7 @@ $ python scripts/convert_shap_e_to_diffusers.py \
|
||||
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
|
||||
--prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
|
||||
--transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
|
||||
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
|
||||
--dump_path /home/yiyi_huggingface_co/model_repo/shap-e-img2img/shap_e_renderer\
|
||||
--debug renderer
|
||||
```
|
||||
"""
|
||||
@@ -373,6 +373,487 @@ def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
|
||||
# renderer
|
||||
|
||||
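## Lookup table for the marching cubes method used in MeshDecoder.
## Each entry of MC_TABLE is one case: a list of triangles, where every triangle is a flat list of six cube-corner
## indices read as three (corner, corner) pairs; `create_mc_lookup_table` maps each pair to an edge index.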
|
||||
|
||||
MC_TABLE = [
|
||||
[],
|
||||
[[0, 1, 0, 2, 0, 4]],
|
||||
[[1, 0, 1, 5, 1, 3]],
|
||||
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]],
|
||||
[[2, 0, 2, 3, 2, 6]],
|
||||
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]],
|
||||
[[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]],
|
||||
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]],
|
||||
[[3, 1, 3, 7, 3, 2]],
|
||||
[[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]],
|
||||
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]],
|
||||
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]],
|
||||
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]],
|
||||
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]],
|
||||
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]],
|
||||
[[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]],
|
||||
[[4, 0, 4, 6, 4, 5]],
|
||||
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]],
|
||||
[[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]],
|
||||
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]],
|
||||
[[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]],
|
||||
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]],
|
||||
[[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]],
|
||||
[[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]],
|
||||
[[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]],
|
||||
[[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]],
|
||||
[[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]],
|
||||
[[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]],
|
||||
[[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]],
|
||||
[[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]],
|
||||
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]],
|
||||
[[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]],
|
||||
[[5, 1, 5, 4, 5, 7]],
|
||||
[[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]],
|
||||
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]],
|
||||
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]],
|
||||
[[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]],
|
||||
[[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]],
|
||||
[[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]],
|
||||
[[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]],
|
||||
[[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]],
|
||||
[[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]],
|
||||
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]],
|
||||
[[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]],
|
||||
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]],
|
||||
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]],
|
||||
[[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]],
|
||||
[[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]],
|
||||
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]],
|
||||
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]],
|
||||
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]],
|
||||
[[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]],
|
||||
[[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]],
|
||||
[[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]],
|
||||
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]],
|
||||
[[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]],
|
||||
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]],
|
||||
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]],
|
||||
[[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]],
|
||||
[[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]],
|
||||
[[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]],
|
||||
[
|
||||
[3, 1, 7, 3, 6, 2],
|
||||
[6, 2, 0, 1, 3, 1],
|
||||
[6, 4, 0, 1, 6, 2],
|
||||
[6, 4, 5, 1, 0, 1],
|
||||
[6, 4, 7, 5, 5, 1],
|
||||
],
|
||||
[
|
||||
[4, 0, 6, 4, 7, 5],
|
||||
[7, 5, 1, 0, 4, 0],
|
||||
[7, 3, 1, 0, 7, 5],
|
||||
[7, 3, 2, 0, 1, 0],
|
||||
[7, 3, 6, 2, 2, 0],
|
||||
],
|
||||
[[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]],
|
||||
[[6, 2, 6, 7, 6, 4]],
|
||||
[[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]],
|
||||
[[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]],
|
||||
[[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]],
|
||||
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]],
|
||||
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]],
|
||||
[[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]],
|
||||
[[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]],
|
||||
[[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]],
|
||||
[[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]],
|
||||
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]],
|
||||
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]],
|
||||
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]],
|
||||
[[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]],
|
||||
[[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]],
|
||||
[[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]],
|
||||
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]],
|
||||
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]],
|
||||
[[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]],
|
||||
[[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]],
|
||||
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]],
|
||||
[[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]],
|
||||
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]],
|
||||
[[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]],
|
||||
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]],
|
||||
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]],
|
||||
[[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]],
|
||||
[
|
||||
[6, 2, 7, 6, 5, 4],
|
||||
[5, 4, 0, 2, 6, 2],
|
||||
[5, 1, 0, 2, 5, 4],
|
||||
[5, 1, 3, 2, 0, 2],
|
||||
[5, 1, 7, 3, 3, 2],
|
||||
],
|
||||
[[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]],
|
||||
[[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]],
|
||||
[
|
||||
[1, 0, 5, 1, 7, 3],
|
||||
[7, 3, 2, 0, 1, 0],
|
||||
[7, 6, 2, 0, 7, 3],
|
||||
[7, 6, 4, 0, 2, 0],
|
||||
[7, 6, 5, 4, 4, 0],
|
||||
],
|
||||
[[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]],
|
||||
[[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]],
|
||||
[[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]],
|
||||
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]],
|
||||
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]],
|
||||
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]],
|
||||
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]],
|
||||
[[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]],
|
||||
[
|
||||
[5, 4, 7, 5, 3, 1],
|
||||
[3, 1, 0, 4, 5, 4],
|
||||
[3, 2, 0, 4, 3, 1],
|
||||
[3, 2, 6, 4, 0, 4],
|
||||
[3, 2, 7, 6, 6, 4],
|
||||
],
|
||||
[[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]],
|
||||
[[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]],
|
||||
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]],
|
||||
[
|
||||
[0, 4, 2, 3, 0, 2],
|
||||
[0, 4, 3, 7, 2, 3],
|
||||
[0, 4, 4, 5, 3, 7],
|
||||
[4, 5, 5, 7, 3, 7],
|
||||
[6, 7, 4, 6, 2, 6],
|
||||
],
|
||||
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]],
|
||||
[
|
||||
[0, 1, 4, 6, 0, 4],
|
||||
[0, 1, 6, 7, 4, 6],
|
||||
[0, 1, 1, 3, 6, 7],
|
||||
[1, 3, 3, 7, 6, 7],
|
||||
[5, 7, 1, 5, 4, 5],
|
||||
],
|
||||
[
|
||||
[6, 7, 4, 6, 0, 2],
|
||||
[0, 2, 3, 7, 6, 7],
|
||||
[0, 1, 3, 7, 0, 2],
|
||||
[0, 1, 5, 7, 3, 7],
|
||||
[0, 1, 4, 5, 5, 7],
|
||||
],
|
||||
[[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]],
|
||||
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]],
|
||||
[[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]],
|
||||
[[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]],
|
||||
[[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]],
|
||||
[[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]],
|
||||
[[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]],
|
||||
[
|
||||
[2, 0, 3, 2, 7, 6],
|
||||
[7, 6, 4, 0, 2, 0],
|
||||
[7, 5, 4, 0, 7, 6],
|
||||
[7, 5, 1, 0, 4, 0],
|
||||
[7, 5, 3, 1, 1, 0],
|
||||
],
|
||||
[[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]],
|
||||
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]],
|
||||
[
|
||||
[0, 2, 1, 5, 0, 1],
|
||||
[0, 2, 5, 7, 1, 5],
|
||||
[0, 2, 2, 6, 5, 7],
|
||||
[2, 6, 6, 7, 5, 7],
|
||||
[3, 7, 2, 3, 1, 3],
|
||||
],
|
||||
[
|
||||
[3, 7, 2, 3, 0, 1],
|
||||
[0, 1, 5, 7, 3, 7],
|
||||
[0, 4, 5, 7, 0, 1],
|
||||
[0, 4, 6, 7, 5, 7],
|
||||
[0, 4, 2, 6, 6, 7],
|
||||
],
|
||||
[[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]],
|
||||
[
|
||||
[5, 7, 1, 5, 0, 4],
|
||||
[0, 4, 6, 7, 5, 7],
|
||||
[0, 2, 6, 7, 0, 4],
|
||||
[0, 2, 3, 7, 6, 7],
|
||||
[0, 2, 1, 3, 3, 7],
|
||||
],
|
||||
[[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]],
|
||||
[[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]],
|
||||
[[7, 5, 7, 3, 7, 6]],
|
||||
[[7, 3, 7, 5, 7, 6]],
|
||||
[[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]],
|
||||
[[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]],
|
||||
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]],
|
||||
[[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]],
|
||||
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]],
|
||||
[[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]],
|
||||
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]],
|
||||
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]],
|
||||
[[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]],
|
||||
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]],
|
||||
[[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]],
|
||||
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]],
|
||||
[[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]],
|
||||
[[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]],
|
||||
[[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]],
|
||||
[[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]],
|
||||
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]],
|
||||
[[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]],
|
||||
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]],
|
||||
[[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]],
|
||||
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]],
|
||||
[[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]],
|
||||
[
|
||||
[1, 3, 5, 4, 1, 5],
|
||||
[1, 3, 4, 6, 5, 4],
|
||||
[1, 3, 3, 2, 4, 6],
|
||||
[3, 2, 2, 6, 4, 6],
|
||||
[7, 6, 3, 7, 5, 7],
|
||||
],
|
||||
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]],
|
||||
[[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]],
|
||||
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]],
|
||||
[
|
||||
[7, 5, 6, 7, 2, 3],
|
||||
[2, 3, 1, 5, 7, 5],
|
||||
[2, 0, 1, 5, 2, 3],
|
||||
[2, 0, 4, 5, 1, 5],
|
||||
[2, 0, 6, 4, 4, 5],
|
||||
],
|
||||
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]],
|
||||
[
|
||||
[4, 6, 5, 4, 1, 0],
|
||||
[1, 0, 2, 6, 4, 6],
|
||||
[1, 3, 2, 6, 1, 0],
|
||||
[1, 3, 7, 6, 2, 6],
|
||||
[1, 3, 5, 7, 7, 6],
|
||||
],
|
||||
[
|
||||
[1, 5, 0, 2, 1, 0],
|
||||
[1, 5, 2, 6, 0, 2],
|
||||
[1, 5, 5, 7, 2, 6],
|
||||
[5, 7, 7, 6, 2, 6],
|
||||
[4, 6, 5, 4, 0, 4],
|
||||
],
|
||||
[[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]],
|
||||
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]],
|
||||
[[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]],
|
||||
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]],
|
||||
[[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]],
|
||||
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]],
|
||||
[[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]],
|
||||
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]],
|
||||
[
|
||||
[2, 3, 6, 2, 4, 0],
|
||||
[4, 0, 1, 3, 2, 3],
|
||||
[4, 5, 1, 3, 4, 0],
|
||||
[4, 5, 7, 3, 1, 3],
|
||||
[4, 5, 6, 7, 7, 3],
|
||||
],
|
||||
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]],
|
||||
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]],
|
||||
[[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]],
|
||||
[[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]],
|
||||
[[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]],
|
||||
[
|
||||
[5, 1, 4, 5, 6, 7],
|
||||
[6, 7, 3, 1, 5, 1],
|
||||
[6, 2, 3, 1, 6, 7],
|
||||
[6, 2, 0, 1, 3, 1],
|
||||
[6, 2, 4, 0, 0, 1],
|
||||
],
|
||||
[[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]],
|
||||
[[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]],
|
||||
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]],
|
||||
[[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]],
|
||||
[[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]],
|
||||
[[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]],
|
||||
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]],
|
||||
[
|
||||
[7, 6, 3, 7, 1, 5],
|
||||
[1, 5, 4, 6, 7, 6],
|
||||
[1, 0, 4, 6, 1, 5],
|
||||
[1, 0, 2, 6, 4, 6],
|
||||
[1, 0, 3, 2, 2, 6],
|
||||
],
|
||||
[
|
||||
[1, 0, 3, 7, 1, 3],
|
||||
[1, 0, 7, 6, 3, 7],
|
||||
[1, 0, 0, 4, 7, 6],
|
||||
[0, 4, 4, 6, 7, 6],
|
||||
[2, 6, 0, 2, 3, 2],
|
||||
],
|
||||
[[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]],
|
||||
[[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]],
|
||||
[
|
||||
[0, 1, 2, 0, 6, 4],
|
||||
[6, 4, 5, 1, 0, 1],
|
||||
[6, 7, 5, 1, 6, 4],
|
||||
[6, 7, 3, 1, 5, 1],
|
||||
[6, 7, 2, 3, 3, 1],
|
||||
],
|
||||
[[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]],
|
||||
[[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]],
|
||||
[
|
||||
[2, 6, 0, 2, 1, 3],
|
||||
[1, 3, 7, 6, 2, 6],
|
||||
[1, 5, 7, 6, 1, 3],
|
||||
[1, 5, 4, 6, 7, 6],
|
||||
[1, 5, 0, 4, 4, 6],
|
||||
],
|
||||
[[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]],
|
||||
[[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]],
|
||||
[[6, 7, 6, 2, 6, 4]],
|
||||
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]],
|
||||
[[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]],
|
||||
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]],
|
||||
[[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]],
|
||||
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]],
|
||||
[[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]],
|
||||
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]],
|
||||
[
|
||||
[7, 3, 5, 7, 4, 6],
|
||||
[4, 6, 2, 3, 7, 3],
|
||||
[4, 0, 2, 3, 4, 6],
|
||||
[4, 0, 1, 3, 2, 3],
|
||||
[4, 0, 5, 1, 1, 3],
|
||||
],
|
||||
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]],
|
||||
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]],
|
||||
[[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]],
|
||||
[
|
||||
[0, 2, 4, 0, 5, 1],
|
||||
[5, 1, 3, 2, 0, 2],
|
||||
[5, 7, 3, 2, 5, 1],
|
||||
[5, 7, 6, 2, 3, 2],
|
||||
[5, 7, 4, 6, 6, 2],
|
||||
],
|
||||
[[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]],
|
||||
[[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]],
|
||||
[[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]],
|
||||
[[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]],
|
||||
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]],
|
||||
[[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]],
|
||||
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]],
|
||||
[
|
||||
[1, 5, 3, 1, 2, 0],
|
||||
[2, 0, 4, 5, 1, 5],
|
||||
[2, 6, 4, 5, 2, 0],
|
||||
[2, 6, 7, 5, 4, 5],
|
||||
[2, 6, 3, 7, 7, 5],
|
||||
],
|
||||
[[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]],
|
||||
[[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]],
|
||||
[
|
||||
[2, 3, 0, 4, 2, 0],
|
||||
[2, 3, 4, 5, 0, 4],
|
||||
[2, 3, 3, 7, 4, 5],
|
||||
[3, 7, 7, 5, 4, 5],
|
||||
[1, 5, 3, 1, 0, 1],
|
||||
],
|
||||
[[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]],
|
||||
[[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]],
|
||||
[
|
||||
[3, 2, 1, 3, 5, 7],
|
||||
[5, 7, 6, 2, 3, 2],
|
||||
[5, 4, 6, 2, 5, 7],
|
||||
[5, 4, 0, 2, 6, 2],
|
||||
[5, 4, 1, 0, 0, 2],
|
||||
],
|
||||
[
|
||||
[4, 5, 0, 4, 2, 6],
|
||||
[2, 6, 7, 5, 4, 5],
|
||||
[2, 3, 7, 5, 2, 6],
|
||||
[2, 3, 1, 5, 7, 5],
|
||||
[2, 3, 0, 1, 1, 5],
|
||||
],
|
||||
[[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]],
|
||||
[[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]],
|
||||
[[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]],
|
||||
[[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]],
|
||||
[[5, 4, 5, 1, 5, 7]],
|
||||
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]],
|
||||
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]],
|
||||
[[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]],
|
||||
[
|
||||
[6, 4, 2, 6, 3, 7],
|
||||
[3, 7, 5, 4, 6, 4],
|
||||
[3, 1, 5, 4, 3, 7],
|
||||
[3, 1, 0, 4, 5, 4],
|
||||
[3, 1, 2, 0, 0, 4],
|
||||
],
|
||||
[[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]],
|
||||
[
|
||||
[0, 4, 1, 0, 3, 2],
|
||||
[3, 2, 6, 4, 0, 4],
|
||||
[3, 7, 6, 4, 3, 2],
|
||||
[3, 7, 5, 4, 6, 4],
|
||||
[3, 7, 1, 5, 5, 4],
|
||||
],
|
||||
[
|
||||
[1, 3, 0, 1, 4, 5],
|
||||
[4, 5, 7, 3, 1, 3],
|
||||
[4, 6, 7, 3, 4, 5],
|
||||
[4, 6, 2, 3, 7, 3],
|
||||
[4, 6, 0, 2, 2, 3],
|
||||
],
|
||||
[[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]],
|
||||
[[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]],
|
||||
[
|
||||
[3, 1, 2, 6, 3, 2],
|
||||
[3, 1, 6, 4, 2, 6],
|
||||
[3, 1, 1, 5, 6, 4],
|
||||
[1, 5, 5, 4, 6, 4],
|
||||
[0, 4, 1, 0, 2, 0],
|
||||
],
|
||||
[[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]],
|
||||
[[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]],
|
||||
[[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]],
|
||||
[[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]],
|
||||
[[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]],
|
||||
[[4, 6, 4, 0, 4, 5]],
|
||||
[[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]],
|
||||
[[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]],
|
||||
[[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]],
|
||||
[[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]],
|
||||
[[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]],
|
||||
[[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]],
|
||||
[[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]],
|
||||
[[3, 7, 3, 1, 3, 2]],
|
||||
[[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]],
|
||||
[[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]],
|
||||
[[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]],
|
||||
[[2, 3, 2, 0, 2, 6]],
|
||||
[[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]],
|
||||
[[1, 5, 1, 0, 1, 3]],
|
||||
[[0, 2, 0, 1, 0, 4]],
|
||||
[],
|
||||
]
|
||||
|
||||
|
||||
def create_mc_lookup_table():
    """Convert `MC_TABLE` into the dense lookup tensors stored in the renderer checkpoint.

    Every `MC_TABLE` entry is a list of triangles, and every triangle is a flat list of corner indices that is read
    as consecutive (corner, corner) pairs. Each pair is translated into an edge index, independent of the order in
    which its two corners are listed.

    Returns:
        A tuple `(cases, masks)`:
            - `cases`: `torch.long` tensor of shape `(256, 5, 3)` holding the edge index of each triangle vertex;
              slots that are not filled from `MC_TABLE` stay zero.
            - `masks`: `torch.bool` tensor of shape `(256, 5)` that is `True` for every triangle slot filled from
              `MC_TABLE`.
    """
    # Position in this list is the edge index; each edge is keyed by its two corners in ascending order.
    corner_pairs = [
        (0, 1),
        (2, 3),
        (4, 5),
        (6, 7),
        (0, 2),
        (1, 3),
        (4, 6),
        (5, 7),
        (0, 4),
        (1, 5),
        (2, 6),
        (3, 7),
    ]
    edge_index_by_corners = {pair: index for index, pair in enumerate(corner_pairs)}

    cases = torch.zeros(256, 5, 3, dtype=torch.long)
    masks = torch.zeros(256, 5, dtype=torch.bool)

    for case_idx, triangles in enumerate(MC_TABLE):
        for tri_idx, corners in enumerate(triangles):
            # Walk the flat corner list two entries at a time: (corners[0], corners[1]), (corners[2], corners[3]), ...
            for vert_idx in range(len(corners) // 2):
                first = corners[2 * vert_idx]
                second = corners[2 * vert_idx + 1]
                # The table lists an edge's corners in either order; normalise before the lookup.
                key = (min(first, second), max(first, second))
                cases[case_idx, tri_idx, vert_idx] = edge_index_by_corners[key]
            # Mark this triangle slot as populated so consumers can ignore the zero padding.
            masks[case_idx, tri_idx] = True

    return cases, masks
|
||||
|
||||
|
||||
RENDERER_CONFIG = {}
|
||||
|
||||
|
||||
@@ -400,7 +881,12 @@ def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)})
|
||||
diffusers_checkpoint.update({"void.background": model.state_dict()["void.background"]})
|
||||
|
||||
cases, masks = create_mc_lookup_table()
|
||||
|
||||
diffusers_checkpoint.update({"mesh_decoder.cases": cases})
|
||||
diffusers_checkpoint.update({"mesh_decoder.masks": masks})
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
shap_e_renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
@@ -106,7 +106,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
shap_e_renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -115,7 +115,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
shap_e_renderer=shap_e_renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
@@ -149,7 +149,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
|
||||
for cpu_offloaded_model in [self.text_encoder, self.prior, self.shap_e_renderer]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
@@ -218,7 +218,7 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
output_type: Optional[str] = "pil", # pil, np, latent, mesh
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -248,8 +248,8 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pt"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -319,30 +319,39 @@ class ShapEPipeline(DiffusionPipeline):
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type not in ["np", "pil", "latent", "mesh"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
images.append(image)
|
||||
if output_type == "mesh":
|
||||
for i, latent in enumerate(latents):
|
||||
mesh = self.shap_e_renderer.decode_to_mesh(
|
||||
latent[None, :],
|
||||
device,
|
||||
)
|
||||
images.append(mesh)
|
||||
|
||||
images = torch.stack(images)
|
||||
else:
|
||||
# np, pil
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.shap_e_renderer.decode_to_image(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
images = torch.stack(images)
|
||||
|
||||
images = images.cpu().numpy()
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
|
||||
@@ -94,7 +94,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`HeunDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
renderer ([`ShapERenderer`]):
|
||||
shap_e_renderer ([`ShapERenderer`]):
|
||||
Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
|
||||
with the NeRF rendering method
|
||||
"""
|
||||
@@ -105,7 +105,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
scheduler: HeunDiscreteScheduler,
|
||||
renderer: ShapERenderer,
|
||||
shap_e_renderer: ShapERenderer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -114,7 +114,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
scheduler=scheduler,
|
||||
renderer=renderer,
|
||||
shap_e_renderer=shap_e_renderer,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
@@ -170,7 +170,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
frame_size: int = 64,
|
||||
output_type: Optional[str] = "pil", # pil, np, latent
|
||||
output_type: Optional[str] = "pil", # pil, np, latent, mesh
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -200,8 +200,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
frame_size (`int`, *optional*, default to 64):
|
||||
the width and height of each image frame of the generated 3d output
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
|
||||
(`torch.Tensor`).
|
||||
(`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -275,32 +274,39 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
|
||||
sample=latents,
|
||||
).prev_sample
|
||||
|
||||
if output_type not in ["np", "pil", "latent", "mesh"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if output_type == "latent":
|
||||
return ShapEPipelineOutput(images=latents)
|
||||
|
||||
images = []
|
||||
for i, latent in enumerate(latents):
|
||||
print()
|
||||
image = self.renderer.decode(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
ray_batch_size=4096,
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
)
|
||||
if output_type == "mesh":
|
||||
for i, latent in enumerate(latents):
|
||||
mesh = self.shap_e_renderer.decode_to_mesh(
|
||||
latent[None, :],
|
||||
device,
|
||||
)
|
||||
images.append(mesh)
|
||||
|
||||
images.append(image)
|
||||
else:
|
||||
# np, pil
|
||||
for i, latent in enumerate(latents):
|
||||
image = self.shap_e_renderer.decode_to_image(
|
||||
latent[None, :],
|
||||
device,
|
||||
size=frame_size,
|
||||
)
|
||||
images.append(image)
|
||||
|
||||
images = torch.stack(images)
|
||||
images = torch.stack(images)
|
||||
|
||||
if output_type not in ["np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
|
||||
images = images.cpu().numpy()
|
||||
|
||||
images = images.cpu().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
if output_type == "pil":
|
||||
images = [self.numpy_to_pil(image) for image in images]
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -116,6 +116,101 @@ def integrate_samples(volume_range, ts, density, channels):
|
||||
return channels, weights, transmittance
|
||||
|
||||
|
||||
def volume_query_points(volume, grid_size):
    """
    Build the coordinates of a regular `grid_size`**3 lattice spanning a bounding volume.

    Args:
        volume:
            object exposing `bbox_min` and `bbox_max` tensors (the two opposite corners of the box). They must
            broadcast against an `(N, 3)` tensor.
        grid_size (`int`):
            number of samples per axis. Must be at least 2, because the lattice spacing is computed as
            `1 / (grid_size - 1)`.

    Returns:
        `torch.Tensor`: float tensor of `grid_size**3` points with 3 coordinates each, enumerated with z varying
        fastest, then y, then x. The first point is `bbox_min` and the last one is `bbox_max`.

    Raises:
        ValueError: if `grid_size` is smaller than 2 (the original arithmetic would silently yield NaN).
    """
    if grid_size < 2:
        raise ValueError(f"`grid_size` must be at least 2 to span the volume, but got {grid_size}")

    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
    # decompose the flat index into (x, y, z) lattice coordinates, z being the fastest-varying axis
    zs = indices % grid_size
    ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
    xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
    combined = torch.stack([xs, ys, zs], dim=1)
    # normalize lattice coordinates to [0, 1] and rescale them into the bounding box
    return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min
|
||||
|
||||
|
||||
def _convert_srgb_to_linear(u: torch.Tensor):
|
||||
return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)
|
||||
|
||||
|
||||
def _create_flat_edge_indices(
|
||||
flat_cube_indices: torch.Tensor,
|
||||
grid_size: Tuple[int, int, int],
|
||||
):
|
||||
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
|
||||
y_offset = num_xs
|
||||
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
|
||||
z_offset = num_xs + num_ys
|
||||
return torch.stack(
|
||||
[
|
||||
# Edges spanning x-axis.
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2],
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 2],
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1,
|
||||
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
|
||||
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1,
|
||||
# Edges spanning y-axis.
|
||||
(
|
||||
y_offset
|
||||
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1
|
||||
),
|
||||
(
|
||||
y_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
|
||||
+ flat_cube_indices[:, 1] * grid_size[2]
|
||||
+ flat_cube_indices[:, 2]
|
||||
+ 1
|
||||
),
|
||||
# Edges spanning z-axis.
|
||||
(
|
||||
z_offset
|
||||
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
|
||||
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
(
|
||||
z_offset
|
||||
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
|
||||
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
|
||||
+ flat_cube_indices[:, 2]
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class VoidNeRFModel(nn.Module):
|
||||
"""
|
||||
Implements the default empty space model where all queries are rendered as background.
|
||||
@@ -368,6 +463,141 @@ class ImportanceRaySampler(nn.Module):
|
||||
return ts
|
||||
|
||||
|
||||
@dataclass
class MeshDecoderOutput(BaseOutput):
    """
    A 3D triangle mesh with optional data at the vertices and faces.

    Args:
        verts (`torch.Tensor` of shape `(N, 3)`):
            array of vertex coordinates
        faces (`torch.Tensor` of shape `(M, 3)`):
            array of triangles, pointing to indices in verts.
        vertex_channels (`Dict[str, torch.Tensor]`, *optional*):
            per-vertex values for each color channel, keyed by channel name. `None` when the mesh carries no
            per-vertex channel data.
    """

    verts: torch.Tensor
    faces: torch.Tensor
    vertex_channels: Optional[Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
class MeshDecoder(nn.Module):
    """
    Construct meshes from Signed distance functions (SDFs) using marching cubes method

    The module owns two lookup-table buffers indexed by the 8-bit corner state of a cube:
    `cases` (256 x 5 x 3, long) holds up to five triangles per state, each as three local edge indices, and
    `masks` (256 x 5, bool) flags which of those five triangle slots are actually used.
    """

    def __init__(self):
        super().__init__()
        # Both tables are created zero-filled here.
        # NOTE(review): presumably the real marching-cubes tables are loaded into these buffers from a
        # checkpoint -- confirm; with all-zero masks `forward` selects no triangles.
        cases = torch.zeros(256, 5, 3, dtype=torch.long)
        masks = torch.zeros(256, 5, dtype=torch.bool)

        self.register_buffer("cases", cases)
        self.register_buffer("masks", masks)

    def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor) -> MeshDecoderOutput:
        """
        For a signed distance field, produce a mesh using marching cubes.

        :param field: a 3D tensor of field values, where negative values correspond
                    to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively.
        :param min_point: a tensor of shape [3] containing the point corresponding
                        to (0, 0, 0) in the field.
        :param size: a tensor of shape [3] containing the per-axis distance from the
                        (0, 0, 0) field corner and the (-1, -1, -1) field corner.
        :return: a `MeshDecoderOutput` holding the interpolated vertex positions (`verts`) and the triangles as
                        indices into them (`faces`); `vertex_channels` is left as `None`.
        """
        assert len(field.shape) == 3, "input must be a 3D scalar field"
        dev = field.device

        # work on the device of the input field
        cases = self.cases.to(dev)
        masks = self.masks.to(dev)

        min_point = min_point.to(dev)
        size = size.to(dev)

        grid_size = field.shape
        # same device/dtype as `size`, so it can be used in the coordinate rescaling below
        grid_size_tensor = torch.tensor(grid_size).to(size)

        # Create bitmasks between 0 and 255 (inclusive) indicating the state
        # of the eight corners of each cube.
        # After the three steps, the corner at offset (dx, dy, dz) of a cube sits at bit dx + 2*dy + 4*dz.
        bitmasks = (field > 0).to(torch.uint8)
        bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
        bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
        bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)

        # Compute corner coordinates across the entire grid.
        # corner_coords[x, y, z] ends up holding (x, y, z) in grid units.
        corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
        corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
            :, None, None
        ]
        corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
            :, None
        ]
        corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)

        # Compute all vertices across all edges in the grid, even though we will
        # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
        # These are all midpoints, and don't account for interpolation (which is
        # done later based on the used edge midpoints).
        # The concatenation order (x-edges, y-edges, z-edges) is the flat edge numbering that
        # `_create_flat_edge_indices` produces indices into.
        edge_midpoints = torch.cat(
            [
                ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
            ],
            dim=0,
        )

        # Create a flat array of [X, Y, Z] indices for each cube.
        cube_indices = torch.zeros(
            grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
        )
        cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
        cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
        cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
        flat_cube_indices = cube_indices.reshape(-1, 3)

        # Create a flat array mapping each cube to 12 global edge indices.
        edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

        # Apply the LUT to figure out the triangles.
        flat_bitmasks = bitmasks.reshape(-1).long()  # must cast to long for indexing to believe this not a mask
        local_tris = cases[flat_bitmasks]
        local_masks = masks[flat_bitmasks]
        # Compute the global edge indices for the triangles.
        # `local_tris` holds per-cube local edge slots; gather translates them to global edge indices.
        global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
            local_tris.shape
        )
        # Select the used triangles for each cube.
        selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

        # Now we have a bunch of indices into the full list of possible vertices,
        # but we want to reduce this list to only the used vertices.
        used_vertex_indices = torch.unique(selected_tris.view(-1))
        used_edge_midpoints = edge_midpoints[used_vertex_indices]
        # lookup table from global edge index to position in the compacted vertex list
        old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
        old_index_to_new_index[used_vertex_indices] = torch.arange(
            len(used_vertex_indices), device=dev, dtype=torch.long
        )

        # Rewrite the triangles to use the new indices
        faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)

        # Compute the actual interpolated coordinates corresponding to edge midpoints.
        # A midpoint has exactly one half-integer coordinate, so floor/ceil recover the two edge endpoints.
        v1 = torch.floor(used_edge_midpoints).to(torch.long)
        v2 = torch.ceil(used_edge_midpoints).to(torch.long)
        s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
        s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
        # convert grid indices to positions inside the box described by `min_point` and `size`
        p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
        p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
        # The signs of s1 and s2 should be different. We want to find
        # t such that t*s2 + (1-t)*s1 = 0.
        t = (s1 / (s1 - s2))[:, None]
        verts = t * p2 + (1 - t) * p1

        return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLPNeRFModelOutput(BaseOutput):
|
||||
density: torch.Tensor
|
||||
@@ -429,7 +659,7 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return mapped_output
|
||||
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse"):
|
||||
def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
|
||||
h = encode_position(position)
|
||||
|
||||
h_preact = h
|
||||
@@ -455,10 +685,17 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if nerf_level == "coarse":
|
||||
h_density = activation["density_coarse"]
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_density = activation["density_fine"]
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
if rendering_mode == "nerf":
|
||||
if nerf_level == "coarse":
|
||||
h_channels = activation["nerf_coarse"]
|
||||
else:
|
||||
h_channels = activation["nerf_fine"]
|
||||
|
||||
elif rendering_mode == "stf":
|
||||
h_channels = activation["stf"]
|
||||
|
||||
density = self.density_activation(h_density)
|
||||
signed_distance = self.sdf_activation(activation["sdf"])
|
||||
@@ -583,6 +820,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
|
||||
self.void = VoidNeRFModel(background=background, channel_scale=255.0)
|
||||
self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
|
||||
self.mesh_decoder = MeshDecoder()
|
||||
|
||||
@torch.no_grad()
|
||||
def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
|
||||
@@ -664,7 +902,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
return channels, weighted_sampler, model_out
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
def decode_to_image(
|
||||
self,
|
||||
latents,
|
||||
device,
|
||||
@@ -707,3 +945,106 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
|
||||
|
||||
return images
|
||||
|
||||
@torch.no_grad()
def decode_to_mesh(
    self,
    latents,
    device,
    grid_size: int = 128,
    query_batch_size: int = 4096,
    texture_channels: Tuple = ("R", "G", "B"),
):
    """
    Decode a latent into a textured triangle mesh using STF rendering.

    The latent is projected into MLP parameters, the signed distance field is sampled on a regular grid, marching
    cubes extracts the surface, and the texture head of the MLP is then queried at every mesh vertex.

    Only a single latent is handled: the SDF grid is reshaped to a batch of one and the first mesh is returned.

    Args:
        latents: latent passed to `self.params_proj`.
        device: device on which the MLP queries are run.
        grid_size (`int`, defaults to 128): number of SDF samples per axis.
        query_batch_size (`int`, defaults to 4096): number of points sent to the MLP per forward call.
        texture_channels (`Tuple`, defaults to `("R", "G", "B")`): names under which the texture outputs are
            stored in `vertex_channels`; its length must match the last dimension of the texture output.

    Returns:
        The mesh produced by `self.mesh_decoder`, with `vertex_channels` filled with one tensor per texture
        channel (values converted from sRGB to linear). If no surface is found the mesh has no vertices and
        every channel is an empty tensor.
    """
    # 1. project the parameters from the generated latents
    projected_params = self.params_proj(latents)

    # 2. update the mlp layers of the renderer
    for name, param in self.mlp.state_dict().items():
        if f"nerstf.{name}" in projected_params.keys():
            param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))

    # 3. decoding with STF rendering
    # 3.1 query the SDF values at vertices along a regular grid_size**3 grid

    query_points = volume_query_points(self.volume, grid_size)
    query_positions = query_points[None].to(device=device, dtype=self.mlp.dtype)

    fields = []

    for idx in range(0, query_positions.shape[1], query_batch_size):
        query_batch = query_positions[:, idx : idx + query_batch_size]

        model_out = self.mlp(
            position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
        )
        fields.append(model_out.signed_distance)

    # predicted SDF values
    fields = torch.cat(fields, dim=1)
    fields = fields.float()

    assert (
        len(fields.shape) == 3 and fields.shape[-1] == 1
    ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"

    fields = fields.reshape(1, *([grid_size] * 3))

    # create a (grid_size + 2)**3 grid
    # - force a negative border around the SDFs to close off all the models.
    full_grid = torch.zeros(
        1,
        grid_size + 2,
        grid_size + 2,
        grid_size + 2,
        device=fields.device,
        dtype=fields.dtype,
    )
    full_grid.fill_(-1.0)
    full_grid[:, 1:-1, 1:-1, 1:-1] = fields
    fields = full_grid

    # apply a differentiable implementation of Marching Cubes to construct meshes
    raw_meshes = [
        self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
        for field in fields
    ]

    max_vertices = max(len(m.verts) for m in raw_meshes)

    if max_vertices == 0:
        # No surface was extracted, so there is nothing to texture. Without this guard the texture loop
        # below would not run and `torch.cat` would fail on an empty list.
        for m in raw_meshes:
            m.vertex_channels = {
                channel: m.verts.new_zeros((0,), dtype=torch.float32) for channel in texture_channels
            }
        return raw_meshes[0]

    # 3.2. query the texture color head at each vertex of the resulting mesh.
    # meshes are padded to `max_vertices` by wrapping around their own vertices
    texture_query_positions = torch.stack(
        [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
        dim=0,
    )
    texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)

    textures = []

    for idx in range(0, texture_query_positions.shape[1], query_batch_size):
        query_batch = texture_query_positions[:, idx : idx + query_batch_size]

        texture_model_out = self.mlp(
            position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
        )
        textures.append(texture_model_out.channels)

    # predict texture color
    textures = torch.cat(textures, dim=1)

    textures = _convert_srgb_to_linear(textures)
    textures = textures.float()

    # 3.3 augment the mesh with texture data
    assert len(textures.shape) == 3 and textures.shape[-1] == len(
        texture_channels
    ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"

    for m, texture in zip(raw_meshes, textures):
        # drop the wrap-around padding added for stacking
        texture = texture[: len(m.verts)]
        m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))

    return raw_meshes[0]
|
||||
|
||||
@@ -103,7 +103,7 @@ if is_torch_available():
|
||||
)
|
||||
from .torch_utils import maybe_allow_in_graph
|
||||
|
||||
from .testing_utils import export_to_gif, export_to_video
|
||||
from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import struct
|
||||
import tempfile
|
||||
import unittest
|
||||
import urllib.parse
|
||||
from contextlib import contextmanager
|
||||
from distutils.util import strtobool
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
@@ -315,6 +318,85 @@ def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) ->
|
||||
return output_gif_path
|
||||
|
||||
|
||||
@contextmanager
def buffered_writer(raw_f):
    """
    Context manager that wraps a raw binary stream in an `io.BufferedWriter`.

    Args:
        raw_f: writable binary stream to wrap.

    Yields:
        The `io.BufferedWriter` wrapping `raw_f`.

    Buffered data is flushed to `raw_f` when the block exits, including when the block raises. `raw_f` is not
    closed explicitly here.
    """
    f = io.BufferedWriter(raw_f)
    try:
        yield f
    finally:
        # flush even on error so that everything written so far reaches the underlying stream
        f.flush()
|
||||
|
||||
|
||||
def export_to_ply(mesh, output_ply_path: str = None) -> str:
    """
    Write a PLY file for a mesh.

    Args:
        mesh:
            mesh object exposing `verts` and `faces` tensors and a `vertex_channels` mapping. When
            `vertex_channels` is not `None`, its `"R"`, `"G"` and `"B"` entries are written as per-vertex colors;
            when it is `None`, only the geometry is written.
        output_ply_path (`str`, *optional*):
            destination path. A temporary `.ply` file name is generated when omitted.

    Returns:
        `str`: the path of the written file.
    """
    if output_ply_path is None:
        output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name

    coords = mesh.verts.detach().cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    rgb = None
    if mesh.vertex_channels is not None:
        rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)

    # `open(..., "wb")` is already buffered; the `with` block guarantees the handle is closed
    with open(output_ply_path, "wb") as f:
        f.write(b"ply\n")
        f.write(b"format binary_little_endian 1.0\n")
        f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
        f.write(b"property float x\n")
        f.write(b"property float y\n")
        f.write(b"property float z\n")
        if rgb is not None:
            f.write(b"property uchar red\n")
            f.write(b"property uchar green\n")
            f.write(b"property uchar blue\n")
        f.write(bytes(f"element face {len(faces)}\n", "ascii"))
        f.write(b"property list uchar int vertex_index\n")
        f.write(b"end_header\n")

        if rgb is not None:
            # scale channel values from [0, 1] to integer byte values
            colors = (rgb * 255.499).round().astype(int)
            vertex_struct = struct.Struct("<3f3B")
            for coord, color in zip(coords.tolist(), colors.tolist()):
                f.write(vertex_struct.pack(*coord, *color))
        else:
            vertex_struct = struct.Struct("<3f")
            for vertex in coords.tolist():
                f.write(vertex_struct.pack(*vertex))

        # each face is stored as a vertex count followed by the vertex indices
        face_struct = struct.Struct("<B3I")
        for tri in faces.tolist():
            f.write(face_struct.pack(len(tri), *tri))

    return output_ply_path
|
||||
|
||||
|
||||
def export_to_obj(mesh, output_obj_path: str = None) -> str:
    """
    Write an OBJ file for a mesh.

    Args:
        mesh:
            mesh object exposing `verts` and `faces` tensors and a `vertex_channels` mapping. When
            `vertex_channels` is not `None`, its `"R"`, `"G"` and `"B"` entries are appended to each vertex line
            as colors; when it is `None`, only the coordinates are written.
        output_obj_path (`str`, *optional*):
            destination path. A temporary `.obj` file name is generated when omitted.

    Returns:
        `str`: the path of the written file.
    """
    if output_obj_path is None:
        output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name

    verts = mesh.verts.detach().cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    if mesh.vertex_channels is None:
        vertex_lines = ["v {} {} {}".format(*coord) for coord in verts.tolist()]
    else:
        vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
        vertex_lines = [
            "v {} {} {} {} {} {}".format(*coord, *color)
            for coord, color in zip(verts.tolist(), vertex_colors.tolist())
        ]

    # face indices are written 1-based
    face_lines = ["f {} {} {}".format(tri[0] + 1, tri[1] + 1, tri[2] + 1) for tri in faces.tolist()]

    with open(output_obj_path, "w") as f:
        f.write("\n".join(vertex_lines + face_lines))

    return output_obj_path
|
||||
|
||||
|
||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
@@ -131,7 +131,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
prior = self.dummy_prior
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
renderer = self.dummy_renderer
|
||||
shap_e_renderer = self.dummy_renderer
|
||||
|
||||
scheduler = HeunDiscreteScheduler(
|
||||
beta_schedule="exp",
|
||||
@@ -145,7 +145,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"prior": prior,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"renderer": renderer,
|
||||
"shap_e_renderer": shap_e_renderer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
prior = self.dummy_prior
|
||||
image_encoder = self.dummy_image_encoder
|
||||
image_processor = self.dummy_image_processor
|
||||
renderer = self.dummy_renderer
|
||||
shap_e_renderer = self.dummy_renderer
|
||||
|
||||
scheduler = HeunDiscreteScheduler(
|
||||
beta_schedule="exp",
|
||||
@@ -157,7 +157,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"prior": prior,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
"renderer": renderer,
|
||||
"shap_e_renderer": shap_e_renderer,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user