From 9324a54c67bc79ab5b572467b7e9546c84726abd Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sun, 18 Jun 2023 19:45:29 +0000
Subject: [PATCH] add

---
 scripts/convert_shap_e_to_diffusers.py        |  62 +++++++-
 src/diffusers/pipelines/shap_e/__init__.py    |   2 +
 src/diffusers/pipelines/shap_e/camera.py      | 142 ++++++++++++++++++
 src/diffusers/pipelines/shap_e/params_proj.py |  96 ++++++++++++
 src/diffusers/pipelines/shap_e/renderer.py    | 103 +++++++++++++
 5 files changed, 402 insertions(+), 3 deletions(-)
 create mode 100644 src/diffusers/pipelines/shap_e/camera.py
 create mode 100644 src/diffusers/pipelines/shap_e/params_proj.py
 create mode 100644 src/diffusers/pipelines/shap_e/renderer.py

diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py
index 4a159fc61f..b6f5780d96 100644
--- a/scripts/convert_shap_e_to_diffusers.py
+++ b/scripts/convert_shap_e_to_diffusers.py
@@ -3,8 +3,10 @@ import tempfile
 
 import torch
 from accelerate import load_checkpoint_and_dispatch
+from collections import OrderedDict
 
 from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.pipelines.shap_e import ShapEParamsProjModel
 
 
 """
@@ -19,8 +21,9 @@ Convert the model:
 ```sh
 $ python scripts/convert_shap_e_to_diffusers.py \
     --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
-    --dump_path /home/yiyi_huggingface_co/model_repo/shape \
-    --debug prior
+    --params_proj_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt \
+    --dump_path /home/yiyi_huggingface_co/model_repo/shape/params_proj \
+    --debug params_proj
 ```
 """
 
@@ -216,6 +219,30 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
 # done prior
 
 
+# params_proj
+
+PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"
+
+PARAMS_PROJ_CONFIG = {}
+
+
+def params_proj_model_from_original_config():
+    model = ShapEParamsProjModel(**PARAMS_PROJ_CONFIG)
+
+    return model
+
+
+def params_proj_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {
+        k: checkpoint[f"{PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"] for k in model.state_dict().keys()
+    }
+
+    return diffusers_checkpoint
+
+
+# done params_proj
+
+
 # TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
 def split_attentions(*, weight, bias, split, chunk_size):
     weights = [None] * split
@@ -267,6 +294,24 @@ def prior(*, args, checkpoint_map_location):
     return prior_model
 
 
+def params_proj(*, args, checkpoint_map_location):
+    print("loading params_proj")
+
+    params_proj_checkpoint = torch.load(args.params_proj_checkpoint_path, map_location=checkpoint_map_location)
+
+    params_proj_model = params_proj_model_from_original_config()
+
+    params_proj_diffusers_checkpoint = params_proj_original_checkpoint_to_diffusers_checkpoint(params_proj_model, params_proj_checkpoint)
+
+    del params_proj_checkpoint
+
+    load_checkpoint_to_model(params_proj_diffusers_checkpoint, params_proj_model, strict=True)
+
+    print("done loading params_proj")
+
+    return params_proj_model
+
+
 def load_checkpoint_to_model(checkpoint, model, strict=False):
     with tempfile.NamedTemporaryFile() as file:
         torch.save(checkpoint, file.name)
@@ -286,7 +331,15 @@ if __name__ == "__main__":
         "--prior_checkpoint_path",
         default=None,
         type=str,
-        required=True,
+        required=False,
         help="Path to the prior checkpoint to convert.",
     )
+
+    parser.add_argument(
+        "--params_proj_checkpoint_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Path to the params_proj checkpoint to convert.",
+    )
 
@@ -320,5 +373,8 @@ if __name__ == "__main__":
     elif args.debug == "prior":
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
+    elif args.debug == "params_proj":
+        params_proj_model = params_proj(args=args, checkpoint_map_location=checkpoint_map_location)
+        params_proj_model.save_pretrained(args.dump_path)
     else:
         raise ValueError(f"unknown debug value : {args.debug}")
diff --git a/src/diffusers/pipelines/shap_e/__init__.py b/src/diffusers/pipelines/shap_e/__init__.py
index bc8c04d50a..a9041801dc 100644
--- a/src/diffusers/pipelines/shap_e/__init__.py
+++ b/src/diffusers/pipelines/shap_e/__init__.py
@@ -13,3 +13,5 @@ except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
 else:
     from .pipeline_shap_e import ShapEPipeline
+    from .params_proj import ShapEParamsProjModel
+    from .renderer import MLPNeRSTFModel
diff --git a/src/diffusers/pipelines/shap_e/camera.py b/src/diffusers/pipelines/shap_e/camera.py
new file mode 100644
index 0000000000..db92c41e37
--- /dev/null
+++ b/src/diffusers/pipelines/shap_e/camera.py
@@ -0,0 +1,142 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+import numpy as np
+
+@dataclass
+class DifferentiableProjectiveCamera:
+    """
+    Implements a batch, differentiable, standard pinhole camera
+    """
+
+    origin: torch.Tensor  # [batch_size x 3]
+    x: torch.Tensor  # [batch_size x 3]
+    y: torch.Tensor  # [batch_size x 3]
+    z: torch.Tensor  # [batch_size x 3]
+    width: int
+    height: int
+    x_fov: float
+    y_fov: float
+
+    def __post_init__(self):
+        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
+        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
+        assert (
+            len(self.x.shape)
+            == len(self.y.shape)
+            == len(self.z.shape)
+            == len(self.origin.shape)
+            == 2
+        )
+
+    def resolution(self):
+        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
+
+    def fov(self):
+        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
+
+    def image_coords(self) -> torch.Tensor:
+        """
+        :return: coords of shape (width * height, 2)
+        """
+        pixel_indices = torch.arange(self.height * self.width)
+        coords = torch.stack(
+            [
+                pixel_indices % self.width,
+                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
+            ],
+            axis=1,
+        )
+        return coords
+
+    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
+        batch_size, *shape, n_coords = coords.shape
+        assert n_coords == 2
+        assert batch_size == self.origin.shape[0]
+        flat = coords.view(batch_size, -1, 2)
+
+        res = self.resolution().to(flat.device)
+        fov = self.fov().to(flat.device)
+
+        fracs = (flat.float() / (res - 1)) * 2 - 1
+        fracs = fracs * torch.tan(fov / 2)
+
+        fracs = fracs.view(batch_size, -1, 2)
+        directions = (
+            self.z.view(batch_size, 1, 3)
+            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
+        )
+        directions = directions / directions.norm(dim=-1, keepdim=True)
+        rays = torch.stack(
+            [
+                torch.broadcast_to(
+                    self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]
+                ),
+                directions,
+            ],
+            dim=2,
+        )
+        return rays.view(batch_size, *shape, 2, 3)
+
+    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
+        """
+        Creates a new camera for the resized view assuming the aspect ratio does not change.
+        """
+        assert width * self.height == height * self.width, "The aspect ratio should not change."
+        return DifferentiableProjectiveCamera(
+            origin=self.origin,
+            x=self.x,
+            y=self.y,
+            z=self.z,
+            width=width,
+            height=height,
+            x_fov=self.x_fov,
+            y_fov=self.y_fov,
+        )
+
+@dataclass
+class DifferentiableCameraBatch:
+    """
+    Annotate a differentiable camera with a multi-dimensional batch shape.
+    """
+
+    shape: Tuple[int]
+    flat_camera: DifferentiableProjectiveCamera
+
+def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
+    origins = []
+    xs = []
+    ys = []
+    zs = []
+    for theta in np.linspace(0, 2 * np.pi, num=20):
+        z = np.array([np.sin(theta), np.cos(theta), -0.5])
+        z /= np.sqrt(np.sum(z**2))
+        origin = -z * 4
+        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
+        y = np.cross(z, x)
+        origins.append(origin)
+        xs.append(x)
+        ys.append(y)
+        zs.append(z)
+    return DifferentiableCameraBatch(
+        shape=(1, len(xs)),
+        flat_camera=DifferentiableProjectiveCamera(
+            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
+            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
+            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
+            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
+            width=size,
+            height=size,
+            x_fov=0.7,
+            y_fov=0.7,
+        ),
+    )
+
+def get_image_coords(width, height) -> torch.Tensor:
+    pixel_indices = torch.arange(height * width)
+    # torch throws warnings for pixel_indices // width
+    pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc")
+    coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)
+    return coords
\ No newline at end of file
diff --git a/src/diffusers/pipelines/shap_e/params_proj.py b/src/diffusers/pipelines/shap_e/params_proj.py
new file mode 100644
index 0000000000..1910889bc6
--- /dev/null
+++ b/src/diffusers/pipelines/shap_e/params_proj.py
@@ -0,0 +1,96 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from typing import Tuple, Optional
+from collections import OrderedDict
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+class ChannelsProj(nn.Module):
+    def __init__(
+        self,
+        *,
+        vectors: int,
+        channels: int,
+        d_latent: int,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(d_latent, vectors * channels)
+        self.norm = nn.LayerNorm(channels)
+        self.d_latent = d_latent
+        self.vectors = vectors
+        self.channels = channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_bvd = x
+        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
+        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
+        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
+        h = self.norm(h)
+
+        h = h + b_vc
+        return h
+
+
+class ShapEParamsProjModel(ModelMixin, ConfigMixin):
+    """
+    Projects the latent representation of a 3D asset into the weights of a multi-layer perceptron (MLP).
+
+    For more details, see the original paper: https://arxiv.org/abs/2305.02463
+    """
+    @register_to_config
+    def __init__(
+        self,
+        *,
+        param_names: Tuple[str] = (
+            "nerstf.mlp.0.weight",
+            "nerstf.mlp.1.weight",
+            "nerstf.mlp.2.weight",
+            "nerstf.mlp.3.weight",
+        ),
+        param_shapes: Tuple[Tuple[int]] = ((256, 93), (256, 256), (256, 256), (256, 256)),
+        d_latent: int = 1024,
+    ):
+        super().__init__()
+
+        # check inputs
+        if len(param_names) != len(param_shapes):
+            raise ValueError(
+                "Must provide same number of `param_names` as `param_shapes`"
+            )
+        self.projections = nn.ModuleDict({})
+        for k, (vectors, channels) in zip(param_names, param_shapes):
+            self.projections[_sanitize_name(k)] = ChannelsProj(
+                vectors=vectors,
+                channels=channels,
+                d_latent=d_latent,
+            )
+
+    def forward(self, x: torch.Tensor):
+        out = {}
+        start = 0
+        for k, shape in zip(self.config.param_names, self.config.param_shapes):
+            vectors, _ = shape
+            end = start + vectors
+            x_bvd = x[:, start:end]
+            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
+            start = end
+        return out
+
+def _sanitize_name(x: str) -> str:
+    return x.replace(".", "__")
\ No newline at end of file
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
new file mode 100644
index 0000000000..45530b1df5
--- /dev/null
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -0,0 +1,103 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+import math
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
+    """
+    Concatenate x and its positional encodings, following NeRF.
+
+    Reference: https://arxiv.org/pdf/2210.04628.pdf
+    """
+    if min_deg == max_deg:
+        return x
+    scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype)
+    *shape, dim = x.shape
+    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
+    assert xb.shape[-1] == dim * (max_deg - min_deg)
+    emb = torch.cat([xb, xb + math.pi / 2.0], dim=-1).sin()
+    return torch.cat([x, emb], dim=-1)
+
+def encode_position(position):
+
+    return posenc_nerf(position, min_deg=0, max_deg=15)
+
+def encode_direction(position, direction=None):
+    if direction is None:
+        return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
+    else:
+        return posenc_nerf(direction, min_deg=0, max_deg=8)
+
+class MLPNeRSTFModel(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        d_hidden: int = 256,
+        n_output: int = 12,
+        n_hidden_layers: int = 6,
+        act_fn: str = "swish",
+        insert_direction_at: int = 4,
+
+    ):
+        super().__init__()
+        # Instantiate the MLP
+
+        # Find out the dimension of encoded position and direction
+        dummy = torch.eye(1, 3)
+        d_posenc_pos = encode_position(position=dummy).shape[-1]
+        d_posenc_dir = encode_direction(position=dummy).shape[-1]
+
+        mlp_widths = [d_hidden] * n_hidden_layers
+        input_widths = [d_posenc_pos] + mlp_widths
+        output_widths = mlp_widths + [n_output]
+
+        if insert_direction_at is not None:
+            input_widths[insert_direction_at] += d_posenc_dir
+
+        self.mlp = nn.ModuleList(
+            [
+                nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)
+            ]
+        )
+
+        if act_fn == "swish":
+            self.activation = lambda x: x * torch.sigmoid(x)  # swish / SiLU: x * sigmoid(x)
+        else:
+            raise ValueError(f"Unsupported activation function {act_fn}")
+
+
+    def forward(self, *, position, direction=None):
+
+        h = encode_position(position)
+        h_preact = h
+        h_directionless = None
+        for i, layer in enumerate(self.mlp):
+            if i == self.config.insert_direction_at:  # 4 in the config
+                h_directionless = h_preact
+                h_direction = encode_direction(position, direction=direction)
+                h = torch.cat([h, h_direction], dim=-1)
+
+            h = layer(h)
+            h_preact = h
+            if i < len(self.mlp) - 1:
+                h = self.activation(h)
+        h_final = h
+        if h_directionless is None:
+            h_directionless = h_preact
+        return h_final, h_directionless
\ No newline at end of file
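
A quick way to smoke-test the two new modules in isolation (a hypothetical snippet, not part of the patch: the dummy latent shape and random query points are assumptions derived from the default configs above, and wiring the projected weights into the NeRSTF MLP is left to the follow-up renderer work):

```py
import torch

from diffusers.pipelines.shap_e import MLPNeRSTFModel, ShapEParamsProjModel

# Project a dummy latent into per-parameter MLP weight tensors.
# (1, 1024, 1024) is an assumed shape matching the default config: 4 params x 256 vectors, d_latent=1024.
params_proj = ShapEParamsProjModel()
latent = torch.randn(1, 1024, 1024)
mlp_params = params_proj(latent)
for name, value in mlp_params.items():
    print(name, tuple(value.shape))  # e.g. nerstf.mlp.0.weight (1, 256, 93)

# Query the (randomly initialized) NeRSTF MLP on made-up points and view directions;
# loading the projected weights into this module is not implemented in this patch.
nerstf = MLPNeRSTFModel()
position = torch.randn(1, 4096, 3)
direction = torch.nn.functional.normalize(torch.randn(1, 4096, 3), dim=-1)
h_final, h_directionless = nerstf(position=position, direction=direction)
print(h_final.shape, h_directionless.shape)  # (1, 4096, 12) and (1, 4096, 256)
```

Design note: `ChannelsProj` folds its `nn.Linear` weight into a `bvd,vcd->bvc` einsum, so each group of `vectors` latent rows is projected independently into one weight tensor; that is why `ShapEParamsProjModel.forward` slices the latent by `vectors` per parameter name.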