diff --git a/docs/source/en/api/pipelines/shap_e.mdx b/docs/source/en/api/pipelines/shap_e.mdx index fcb32da31b..2eec12e6a6 100644 --- a/docs/source/en/api/pipelines/shap_e.mdx +++ b/docs/source/en/api/pipelines/shap_e.mdx @@ -128,6 +128,63 @@ gif_path = export_to_gif(images[0], "burger_3d.gif") ``` ![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif) +### Generate mesh + +For both [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`], you can generate mesh output by passing `output_type` as `mesh` to the pipeline, and then use the [`ShapEPipeline.export_to_ply`] utility function to save the output as a `ply` file. We also provide a [`ShapEPipeline.export_to_obj`] function that you can use to save mesh outputs as `obj` files. + +```python +import torch + +from diffusers import DiffusionPipeline +from diffusers.utils import export_to_ply + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +repo = "openai/shap-e" +pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16") +pipe = pipe.to(device) + +guidance_scale = 15.0 +prompt = "A birthday cupcake" + +images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images + +ply_path = export_to_ply(images[0], "3d_cake.ply") +print(f"saved to folder: {ply_path}") +``` + +Huggingface Datasets supports mesh visualization for mesh files in `glb` format. Below we will show you how to convert your mesh file into `glb` format so that you can use the Dataset viewer to render 3D objects. + +We need to install `trimesh` library. + +``` +pip install trimesh +``` + +To convert the mesh file into `glb` format, + +```python +import trimesh + +mesh = trimesh.load("3d_cake.ply") +mesh.export("3d_cake.glb", file_type="glb") +``` + +By default, the mesh output of Shap-E is from the bottom viewpoint; you can change the default viewpoint by applying a rotation transformation + +```python +import trimesh +import numpy as np + +mesh = trimesh.load("3d_cake.ply") +rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]) +mesh = mesh.apply_transform(rot) +mesh.export("3d_cake.glb", file_type="glb") +``` + +Now you can upload your mesh file to your dataset and visualize it! Here is the link to the 3D cake we just generated +https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/shap_e/3d_cake.glb + ## ShapEPipeline [[autodoc]] ShapEPipeline - all diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py index d92db176f4..cacd2f7ba3 100644 --- a/scripts/convert_shap_e_to_diffusers.py +++ b/scripts/convert_shap_e_to_diffusers.py @@ -22,7 +22,7 @@ $ python scripts/convert_shap_e_to_diffusers.py \ --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \ --prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \ --transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\ - --dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\ + --dump_path /home/yiyi_huggingface_co/model_repo/shap-e-img2img/shap_e_renderer\ --debug renderer ``` """ @@ -373,6 +373,487 @@ def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): # renderer +## create the lookup table for marching cubes method used in MeshDecoder + +MC_TABLE = [ + [], + [[0, 1, 0, 2, 0, 4]], + [[1, 0, 1, 5, 1, 3]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]], + [[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]], + [[3, 1, 3, 7, 3, 2]], + [[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]], + [[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]], + [[4, 0, 4, 6, 4, 5]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]], + [[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]], + [[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]], + [[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]], + [[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]], + [[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]], + [[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]], + [[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]], + [[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]], + [[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]], + [[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]], + [[5, 1, 5, 4, 5, 7]], + [[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]], + [[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]], + [[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]], + [[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]], + [[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]], + [[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]], + [[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]], + [[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]], + [[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]], + [[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]], + [[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]], + [[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]], + [[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]], + [[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]], + [[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]], + [[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]], + [[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]], + [ + [3, 1, 7, 3, 6, 2], + [6, 2, 0, 1, 3, 1], + [6, 4, 0, 1, 6, 2], + [6, 4, 5, 1, 0, 1], + [6, 4, 7, 5, 5, 1], + ], + [ + [4, 0, 6, 4, 7, 5], + [7, 5, 1, 0, 4, 0], + [7, 3, 1, 0, 7, 5], + [7, 3, 2, 0, 1, 0], + [7, 3, 6, 2, 2, 0], + ], + [[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]], + [[6, 2, 6, 7, 6, 4]], + [[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]], + [[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]], + [[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]], + [[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]], + [[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]], + [[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]], + [[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]], + [[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]], + [[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]], + [[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]], + [[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]], + [[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]], + [[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]], + [[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]], + [[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]], + [ + [6, 2, 7, 6, 5, 4], + [5, 4, 0, 2, 6, 2], + [5, 1, 0, 2, 5, 4], + [5, 1, 3, 2, 0, 2], + [5, 1, 7, 3, 3, 2], + ], + [[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]], + [[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]], + [ + [1, 0, 5, 1, 7, 3], + [7, 3, 2, 0, 1, 0], + [7, 6, 2, 0, 7, 3], + [7, 6, 4, 0, 2, 0], + [7, 6, 5, 4, 4, 0], + ], + [[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]], + [[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]], + [[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]], + [[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]], + [ + [5, 4, 7, 5, 3, 1], + [3, 1, 0, 4, 5, 4], + [3, 2, 0, 4, 3, 1], + [3, 2, 6, 4, 0, 4], + [3, 2, 7, 6, 6, 4], + ], + [[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]], + [[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]], + [ + [0, 4, 2, 3, 0, 2], + [0, 4, 3, 7, 2, 3], + [0, 4, 4, 5, 3, 7], + [4, 5, 5, 7, 3, 7], + [6, 7, 4, 6, 2, 6], + ], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]], + [ + [0, 1, 4, 6, 0, 4], + [0, 1, 6, 7, 4, 6], + [0, 1, 1, 3, 6, 7], + [1, 3, 3, 7, 6, 7], + [5, 7, 1, 5, 4, 5], + ], + [ + [6, 7, 4, 6, 0, 2], + [0, 2, 3, 7, 6, 7], + [0, 1, 3, 7, 0, 2], + [0, 1, 5, 7, 3, 7], + [0, 1, 4, 5, 5, 7], + ], + [[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]], + [[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]], + [[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]], + [[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]], + [[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]], + [[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]], + [ + [2, 0, 3, 2, 7, 6], + [7, 6, 4, 0, 2, 0], + [7, 5, 4, 0, 7, 6], + [7, 5, 1, 0, 4, 0], + [7, 5, 3, 1, 1, 0], + ], + [[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]], + [ + [0, 2, 1, 5, 0, 1], + [0, 2, 5, 7, 1, 5], + [0, 2, 2, 6, 5, 7], + [2, 6, 6, 7, 5, 7], + [3, 7, 2, 3, 1, 3], + ], + [ + [3, 7, 2, 3, 0, 1], + [0, 1, 5, 7, 3, 7], + [0, 4, 5, 7, 0, 1], + [0, 4, 6, 7, 5, 7], + [0, 4, 2, 6, 6, 7], + ], + [[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]], + [ + [5, 7, 1, 5, 0, 4], + [0, 4, 6, 7, 5, 7], + [0, 2, 6, 7, 0, 4], + [0, 2, 3, 7, 6, 7], + [0, 2, 1, 3, 3, 7], + ], + [[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]], + [[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]], + [[7, 5, 7, 3, 7, 6]], + [[7, 3, 7, 5, 7, 6]], + [[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]], + [[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]], + [[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]], + [[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]], + [[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]], + [[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]], + [[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]], + [[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]], + [[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]], + [[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]], + [[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]], + [[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]], + [[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]], + [ + [1, 3, 5, 4, 1, 5], + [1, 3, 4, 6, 5, 4], + [1, 3, 3, 2, 4, 6], + [3, 2, 2, 6, 4, 6], + [7, 6, 3, 7, 5, 7], + ], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]], + [[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]], + [ + [7, 5, 6, 7, 2, 3], + [2, 3, 1, 5, 7, 5], + [2, 0, 1, 5, 2, 3], + [2, 0, 4, 5, 1, 5], + [2, 0, 6, 4, 4, 5], + ], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]], + [ + [4, 6, 5, 4, 1, 0], + [1, 0, 2, 6, 4, 6], + [1, 3, 2, 6, 1, 0], + [1, 3, 7, 6, 2, 6], + [1, 3, 5, 7, 7, 6], + ], + [ + [1, 5, 0, 2, 1, 0], + [1, 5, 2, 6, 0, 2], + [1, 5, 5, 7, 2, 6], + [5, 7, 7, 6, 2, 6], + [4, 6, 5, 4, 0, 4], + ], + [[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]], + [[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]], + [[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]], + [[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]], + [ + [2, 3, 6, 2, 4, 0], + [4, 0, 1, 3, 2, 3], + [4, 5, 1, 3, 4, 0], + [4, 5, 7, 3, 1, 3], + [4, 5, 6, 7, 7, 3], + ], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]], + [[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]], + [[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]], + [[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]], + [ + [5, 1, 4, 5, 6, 7], + [6, 7, 3, 1, 5, 1], + [6, 2, 3, 1, 6, 7], + [6, 2, 0, 1, 3, 1], + [6, 2, 4, 0, 0, 1], + ], + [[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]], + [[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]], + [[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]], + [[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]], + [[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]], + [ + [7, 6, 3, 7, 1, 5], + [1, 5, 4, 6, 7, 6], + [1, 0, 4, 6, 1, 5], + [1, 0, 2, 6, 4, 6], + [1, 0, 3, 2, 2, 6], + ], + [ + [1, 0, 3, 7, 1, 3], + [1, 0, 7, 6, 3, 7], + [1, 0, 0, 4, 7, 6], + [0, 4, 4, 6, 7, 6], + [2, 6, 0, 2, 3, 2], + ], + [[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]], + [[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]], + [ + [0, 1, 2, 0, 6, 4], + [6, 4, 5, 1, 0, 1], + [6, 7, 5, 1, 6, 4], + [6, 7, 3, 1, 5, 1], + [6, 7, 2, 3, 3, 1], + ], + [[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]], + [[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]], + [ + [2, 6, 0, 2, 1, 3], + [1, 3, 7, 6, 2, 6], + [1, 5, 7, 6, 1, 3], + [1, 5, 4, 6, 7, 6], + [1, 5, 0, 4, 4, 6], + ], + [[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]], + [[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]], + [[6, 7, 6, 2, 6, 4]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]], + [[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]], + [[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]], + [[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]], + [ + [7, 3, 5, 7, 4, 6], + [4, 6, 2, 3, 7, 3], + [4, 0, 2, 3, 4, 6], + [4, 0, 1, 3, 2, 3], + [4, 0, 5, 1, 1, 3], + ], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]], + [[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]], + [ + [0, 2, 4, 0, 5, 1], + [5, 1, 3, 2, 0, 2], + [5, 7, 3, 2, 5, 1], + [5, 7, 6, 2, 3, 2], + [5, 7, 4, 6, 6, 2], + ], + [[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]], + [[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]], + [[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]], + [[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]], + [[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]], + [ + [1, 5, 3, 1, 2, 0], + [2, 0, 4, 5, 1, 5], + [2, 6, 4, 5, 2, 0], + [2, 6, 7, 5, 4, 5], + [2, 6, 3, 7, 7, 5], + ], + [[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]], + [[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]], + [ + [2, 3, 0, 4, 2, 0], + [2, 3, 4, 5, 0, 4], + [2, 3, 3, 7, 4, 5], + [3, 7, 7, 5, 4, 5], + [1, 5, 3, 1, 0, 1], + ], + [[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]], + [[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]], + [ + [3, 2, 1, 3, 5, 7], + [5, 7, 6, 2, 3, 2], + [5, 4, 6, 2, 5, 7], + [5, 4, 0, 2, 6, 2], + [5, 4, 1, 0, 0, 2], + ], + [ + [4, 5, 0, 4, 2, 6], + [2, 6, 7, 5, 4, 5], + [2, 3, 7, 5, 2, 6], + [2, 3, 1, 5, 7, 5], + [2, 3, 0, 1, 1, 5], + ], + [[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]], + [[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]], + [[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]], + [[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]], + [[5, 4, 5, 1, 5, 7]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]], + [[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]], + [ + [6, 4, 2, 6, 3, 7], + [3, 7, 5, 4, 6, 4], + [3, 1, 5, 4, 3, 7], + [3, 1, 0, 4, 5, 4], + [3, 1, 2, 0, 0, 4], + ], + [[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]], + [ + [0, 4, 1, 0, 3, 2], + [3, 2, 6, 4, 0, 4], + [3, 7, 6, 4, 3, 2], + [3, 7, 5, 4, 6, 4], + [3, 7, 1, 5, 5, 4], + ], + [ + [1, 3, 0, 1, 4, 5], + [4, 5, 7, 3, 1, 3], + [4, 6, 7, 3, 4, 5], + [4, 6, 2, 3, 7, 3], + [4, 6, 0, 2, 2, 3], + ], + [[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]], + [[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]], + [ + [3, 1, 2, 6, 3, 2], + [3, 1, 6, 4, 2, 6], + [3, 1, 1, 5, 6, 4], + [1, 5, 5, 4, 6, 4], + [0, 4, 1, 0, 2, 0], + ], + [[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]], + [[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]], + [[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]], + [[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]], + [[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]], + [[4, 6, 4, 0, 4, 5]], + [[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]], + [[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]], + [[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]], + [[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]], + [[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]], + [[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]], + [[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]], + [[3, 7, 3, 1, 3, 2]], + [[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]], + [[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]], + [[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]], + [[2, 3, 2, 0, 2, 6]], + [[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]], + [[1, 5, 1, 0, 1, 3]], + [[0, 2, 0, 1, 0, 4]], + [], +] + + +def create_mc_lookup_table(): + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + edge_to_index = { + (0, 1): 0, + (2, 3): 1, + (4, 5): 2, + (6, 7): 3, + (0, 2): 4, + (1, 3): 5, + (4, 6): 6, + (5, 7): 7, + (0, 4): 8, + (1, 5): 9, + (2, 6): 10, + (3, 7): 11, + } + + for i, case in enumerate(MC_TABLE): + for j, tri in enumerate(case): + for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])): + cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)] + masks[i, j] = True + return cases, masks + + RENDERER_CONFIG = {} @@ -400,7 +881,12 @@ def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint } ) - diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)}) + diffusers_checkpoint.update({"void.background": model.state_dict()["void.background"]}) + + cases, masks = create_mc_lookup_table() + + diffusers_checkpoint.update({"mesh_decoder.cases": cases}) + diffusers_checkpoint.update({"mesh_decoder.masks": masks}) return diffusers_checkpoint diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index fdcbe55086..d93047ec66 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -95,7 +95,7 @@ class ShapEPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`HeunDiscreteScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. - renderer ([`ShapERenderer`]): + shap_e_renderer ([`ShapERenderer`]): Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects with the NeRF rendering method """ @@ -106,7 +106,7 @@ class ShapEPipeline(DiffusionPipeline): text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: HeunDiscreteScheduler, - renderer: ShapERenderer, + shap_e_renderer: ShapERenderer, ): super().__init__() @@ -115,7 +115,7 @@ class ShapEPipeline(DiffusionPipeline): text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, - renderer=renderer, + shap_e_renderer=shap_e_renderer, ) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents @@ -149,7 +149,7 @@ class ShapEPipeline(DiffusionPipeline): torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None - for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]: + for cpu_offloaded_model in [self.text_encoder, self.prior, self.shap_e_renderer]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: @@ -218,7 +218,7 @@ class ShapEPipeline(DiffusionPipeline): latents: Optional[torch.FloatTensor] = None, guidance_scale: float = 4.0, frame_size: int = 64, - output_type: Optional[str] = "pil", # pil, np, latent + output_type: Optional[str] = "pil", # pil, np, latent, mesh return_dict: bool = True, ): """ @@ -248,8 +248,8 @@ class ShapEPipeline(DiffusionPipeline): frame_size (`int`, *optional*, default to 64): the width and height of each image frame of the generated 3d output output_type (`str`, *optional*, defaults to `"pt"`): - The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` - (`torch.Tensor`). + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. @@ -319,30 +319,39 @@ class ShapEPipeline(DiffusionPipeline): sample=latents, ).prev_sample + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + if output_type == "latent": return ShapEPipelineOutput(images=latents) images = [] - for i, latent in enumerate(latents): - image = self.renderer.decode( - latent[None, :], - device, - size=frame_size, - ray_batch_size=4096, - n_coarse_samples=64, - n_fine_samples=128, - ) - images.append(image) + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) - images = torch.stack(images) + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) - if output_type not in ["np", "pil"]: - raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + images = torch.stack(images) - images = images.cpu().numpy() + images = images.cpu().numpy() - if output_type == "pil": - images = [self.numpy_to_pil(image) for image in images] + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index 08c585c5ad..1144656029 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -94,7 +94,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`HeunDiscreteScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. - renderer ([`ShapERenderer`]): + shap_e_renderer ([`ShapERenderer`]): Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects with the NeRF rendering method """ @@ -105,7 +105,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, scheduler: HeunDiscreteScheduler, - renderer: ShapERenderer, + shap_e_renderer: ShapERenderer, ): super().__init__() @@ -114,7 +114,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): image_encoder=image_encoder, image_processor=image_processor, scheduler=scheduler, - renderer=renderer, + shap_e_renderer=shap_e_renderer, ) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents @@ -170,7 +170,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): latents: Optional[torch.FloatTensor] = None, guidance_scale: float = 4.0, frame_size: int = 64, - output_type: Optional[str] = "pil", # pil, np, latent + output_type: Optional[str] = "pil", # pil, np, latent, mesh return_dict: bool = True, ): """ @@ -200,8 +200,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): frame_size (`int`, *optional*, default to 64): the width and height of each image frame of the generated 3d output output_type (`str`, *optional*, defaults to `"pt"`): - The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` - (`torch.Tensor`). + (`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. @@ -275,32 +274,39 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): sample=latents, ).prev_sample + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + if output_type == "latent": return ShapEPipelineOutput(images=latents) images = [] - for i, latent in enumerate(latents): - print() - image = self.renderer.decode( - latent[None, :], - device, - size=frame_size, - ray_batch_size=4096, - n_coarse_samples=64, - n_fine_samples=128, - ) + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) - images.append(image) + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) - images = torch.stack(images) + images = torch.stack(images) - if output_type not in ["np", "pil"]: - raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + images = images.cpu().numpy() - images = images.cpu().numpy() - - if output_type == "pil": - images = [self.numpy_to_pil(image) for image in images] + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py index 8b075e671f..ac5c06042e 100644 --- a/src/diffusers/pipelines/shap_e/renderer.py +++ b/src/diffusers/pipelines/shap_e/renderer.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -116,6 +116,101 @@ def integrate_samples(volume_range, ts, density, channels): return channels, weights, transmittance +def volume_query_points(volume, grid_size): + indices = torch.arange(grid_size**3, device=volume.bbox_min.device) + zs = indices % grid_size + ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size + xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size + combined = torch.stack([xs, ys, zs], dim=1) + return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min + + +def _convert_srgb_to_linear(u: torch.Tensor): + return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) + + +def _create_flat_edge_indices( + flat_cube_indices: torch.Tensor, + grid_size: Tuple[int, int, int], +): + num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] + y_offset = num_xs + num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] + z_offset = num_xs + num_ys + return torch.stack( + [ + # Edges spanning x-axis. + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + # Edges spanning y-axis. + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + # Edges spanning z-axis. + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ], + dim=-1, + ) + + class VoidNeRFModel(nn.Module): """ Implements the default empty space model where all queries are rendered as background. @@ -368,6 +463,141 @@ class ImportanceRaySampler(nn.Module): return ts +@dataclass +class MeshDecoderOutput(BaseOutput): + """ + A 3D triangle mesh with optional data at the vertices and faces. + + Args: + verts (`torch.Tensor` of shape `(N, 3)`): + array of vertext coordinates + faces (`torch.Tensor` of shape `(N, 3)`): + array of triangles, pointing to indices in verts. + vertext_channels (Dict): + vertext coordinates for each color channel + """ + + verts: torch.Tensor + faces: torch.Tensor + vertex_channels: Dict[str, torch.Tensor] + + +class MeshDecoder(nn.Module): + """ + Construct meshes from Signed distance functions (SDFs) using marching cubes method + """ + + def __init__(self): + super().__init__() + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + self.register_buffer("cases", cases) + self.register_buffer("masks", masks) + + def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): + """ + For a signed distance field, produce a mesh using marching cubes. + + :param field: a 3D tensor of field values, where negative values correspond + to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. + :param min_point: a tensor of shape [3] containing the point corresponding + to (0, 0, 0) in the field. + :param size: a tensor of shape [3] containing the per-axis distance from the + (0, 0, 0) field corner and the (-1, -1, -1) field corner. + """ + assert len(field.shape) == 3, "input must be a 3D scalar field" + dev = field.device + + cases = self.cases.to(dev) + masks = self.masks.to(dev) + + min_point = min_point.to(dev) + size = size.to(dev) + + grid_size = field.shape + grid_size_tensor = torch.tensor(grid_size).to(size) + + # Create bitmasks between 0 and 255 (inclusive) indicating the state + # of the eight corners of each cube. + bitmasks = (field > 0).to(torch.uint8) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + + # Compute corner coordinates across the entire grid. + corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) + corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[ + :, None, None + ] + corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[ + :, None + ] + corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype) + + # Compute all vertices across all edges in the grid, even though we will + # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. + # These are all midpoints, and don't account for interpolation (which is + # done later based on the used edge midpoints). + edge_midpoints = torch.cat( + [ + ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), + ], + dim=0, + ) + + # Create a flat array of [X, Y, Z] indices for each cube. + cube_indices = torch.zeros( + grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long + ) + cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None] + cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None] + cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) + flat_cube_indices = cube_indices.reshape(-1, 3) + + # Create a flat array mapping each cube to 12 global edge indices. + edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) + + # Apply the LUT to figure out the triangles. + flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask + local_tris = cases[flat_bitmasks] + local_masks = masks[flat_bitmasks] + # Compute the global edge indices for the triangles. + global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( + local_tris.shape + ) + # Select the used triangles for each cube. + selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] + + # Now we have a bunch of indices into the full list of possible vertices, + # but we want to reduce this list to only the used vertices. + used_vertex_indices = torch.unique(selected_tris.view(-1)) + used_edge_midpoints = edge_midpoints[used_vertex_indices] + old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) + old_index_to_new_index[used_vertex_indices] = torch.arange( + len(used_vertex_indices), device=dev, dtype=torch.long + ) + + # Rewrite the triangles to use the new indices + faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape) + + # Compute the actual interpolated coordinates corresponding to edge midpoints. + v1 = torch.floor(used_edge_midpoints).to(torch.long) + v2 = torch.ceil(used_edge_midpoints).to(torch.long) + s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] + s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] + p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point + p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point + # The signs of s1 and s2 should be different. We want to find + # t such that t*s2 + (1-t)*s1 = 0. + t = (s1 / (s1 - s2))[:, None] + verts = t * p2 + (1 - t) * p1 + + return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None) + + @dataclass class MLPNeRFModelOutput(BaseOutput): density: torch.Tensor @@ -429,7 +659,7 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin): return mapped_output - def forward(self, *, position, direction, ts, nerf_level="coarse"): + def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"): h = encode_position(position) h_preact = h @@ -455,10 +685,17 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin): if nerf_level == "coarse": h_density = activation["density_coarse"] - h_channels = activation["nerf_coarse"] else: h_density = activation["density_fine"] - h_channels = activation["nerf_fine"] + + if rendering_mode == "nerf": + if nerf_level == "coarse": + h_channels = activation["nerf_coarse"] + else: + h_channels = activation["nerf_fine"] + + elif rendering_mode == "stf": + h_channels = activation["stf"] density = self.density_activation(h_density) signed_distance = self.sdf_activation(activation["sdf"]) @@ -583,6 +820,7 @@ class ShapERenderer(ModelMixin, ConfigMixin): self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) self.void = VoidNeRFModel(background=background, channel_scale=255.0) self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + self.mesh_decoder = MeshDecoder() @torch.no_grad() def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): @@ -664,7 +902,7 @@ class ShapERenderer(ModelMixin, ConfigMixin): return channels, weighted_sampler, model_out @torch.no_grad() - def decode( + def decode_to_image( self, latents, device, @@ -707,3 +945,106 @@ class ShapERenderer(ModelMixin, ConfigMixin): images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) return images + + @torch.no_grad() + def decode_to_mesh( + self, + latents, + device, + grid_size: int = 128, + query_batch_size: int = 4096, + texture_channels: Tuple = ("R", "G", "B"), + ): + # 1. project the the paramters from the generated latents + projected_params = self.params_proj(latents) + + # 2. update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # 3. decoding with STF rendering + # 3.1 query the SDF values at vertices along a regular 128**3 grid + + query_points = volume_query_points(self.volume, grid_size) + query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype) + + fields = [] + + for idx in range(0, query_positions.shape[1], query_batch_size): + query_batch = query_positions[:, idx : idx + query_batch_size] + + model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + fields.append(model_out.signed_distance) + + # predicted SDF values + fields = torch.cat(fields, dim=1) + fields = fields.float() + + assert ( + len(fields.shape) == 3 and fields.shape[-1] == 1 + ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + + fields = fields.reshape(1, *([grid_size] * 3)) + + # create grid 128 x 128 x 128 + # - force a negative border around the SDFs to close off all the models. + full_grid = torch.zeros( + 1, + grid_size + 2, + grid_size + 2, + grid_size + 2, + device=fields.device, + dtype=fields.dtype, + ) + full_grid.fill_(-1.0) + full_grid[:, 1:-1, 1:-1, 1:-1] = fields + fields = full_grid + + # apply a differentiable implementation of Marching Cubes to construct meshs + raw_meshes = [] + mesh_mask = [] + + for field in fields: + raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min) + mesh_mask.append(True) + raw_meshes.append(raw_mesh) + + mesh_mask = torch.tensor(mesh_mask, device=fields.device) + max_vertices = max(len(m.verts) for m in raw_meshes) + + # 3.2. query the texture color head at each vertex of the resulting mesh. + texture_query_positions = torch.stack( + [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], + dim=0, + ) + texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype) + + textures = [] + + for idx in range(0, texture_query_positions.shape[1], query_batch_size): + query_batch = texture_query_positions[:, idx : idx + query_batch_size] + + texture_model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + textures.append(texture_model_out.channels) + + # predict texture color + textures = torch.cat(textures, dim=1) + + textures = _convert_srgb_to_linear(textures) + textures = textures.float() + + # 3.3 augument the mesh with texture data + assert len(textures.shape) == 3 and textures.shape[-1] == len( + texture_channels + ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1))) + + return raw_meshes[0] diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 98fac64497..fb54c151b2 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -103,7 +103,7 @@ if is_torch_available(): ) from .torch_utils import maybe_allow_in_graph -from .testing_utils import export_to_gif, export_to_video +from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video logger = get_logger(__name__) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 64eb3ac925..3976be0fd7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,12 +1,15 @@ import inspect +import io import logging import multiprocessing import os import random import re +import struct import tempfile import unittest import urllib.parse +from contextlib import contextmanager from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path @@ -315,6 +318,85 @@ def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> return output_gif_path +@contextmanager +def buffered_writer(raw_f): + f = io.BufferedWriter(raw_f) + yield f + f.flush() + + +def export_to_ply(mesh, output_ply_path: str = None): + """ + Write a PLY file for a mesh. + """ + if output_ply_path is None: + output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name + + coords = mesh.verts.detach().cpu().numpy() + faces = mesh.faces.cpu().numpy() + rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1) + + with buffered_writer(open(output_ply_path, "wb")) as f: + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + if rgb is not None: + f.write(b"property uchar red\n") + f.write(b"property uchar green\n") + f.write(b"property uchar blue\n") + if faces is not None: + f.write(bytes(f"element face {len(faces)}\n", "ascii")) + f.write(b"property list uchar int vertex_index\n") + f.write(b"end_header\n") + + if rgb is not None: + rgb = (rgb * 255.499).round().astype(int) + vertices = [ + (*coord, *rgb) + for coord, rgb in zip( + coords.tolist(), + rgb.tolist(), + ) + ] + format = struct.Struct("<3f3B") + for item in vertices: + f.write(format.pack(*item)) + else: + format = struct.Struct("<3f") + for vertex in coords.tolist(): + f.write(format.pack(*vertex)) + + if faces is not None: + format = struct.Struct(" str: if is_opencv_available(): import cv2 diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py index d095dd9d49..90ff37de6e 100644 --- a/tests/pipelines/shap_e/test_shap_e.py +++ b/tests/pipelines/shap_e/test_shap_e.py @@ -131,7 +131,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase): prior = self.dummy_prior text_encoder = self.dummy_text_encoder tokenizer = self.dummy_tokenizer - renderer = self.dummy_renderer + shap_e_renderer = self.dummy_renderer scheduler = HeunDiscreteScheduler( beta_schedule="exp", @@ -145,7 +145,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "prior": prior, "text_encoder": text_encoder, "tokenizer": tokenizer, - "renderer": renderer, + "shap_e_renderer": shap_e_renderer, "scheduler": scheduler, } diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py index f6638a994f..0dffac98aa 100644 --- a/tests/pipelines/shap_e/test_shap_e_img2img.py +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -143,7 +143,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): prior = self.dummy_prior image_encoder = self.dummy_image_encoder image_processor = self.dummy_image_processor - renderer = self.dummy_renderer + shap_e_renderer = self.dummy_renderer scheduler = HeunDiscreteScheduler( beta_schedule="exp", @@ -157,7 +157,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "prior": prior, "image_encoder": image_encoder, "image_processor": image_processor, - "renderer": renderer, + "shap_e_renderer": shap_e_renderer, "scheduler": scheduler, }