Commit "add" (mirror of https://github.com/huggingface/diffusers.git)
@@ -3,8 +3,10 @@ import tempfile

import torch

from accelerate import load_checkpoint_and_dispatch
from collections import OrderedDict

from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapEParamsProjModel


"""
@@ -19,8 +21,9 @@ Convert the model:

```sh
$ python scripts/convert_shap_e_to_diffusers.py \
  --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
  --dump_path /home/yiyi_huggingface_co/model_repo/shape \
  --debug prior

$ python scripts/convert_shap_e_to_diffusers.py \
  --params_proj_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt \
  --dump_path /home/yiyi_huggingface_co/model_repo/shape/params_proj \
  --debug params_proj
```
"""
@@ -216,6 +219,30 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
# done prior


# params_proj

PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"

PARAMS_PROJ_CONFIG = {}


def params_proj_model_from_original_config():
    model = ShapEParamsProjModel(**PARAMS_PROJ_CONFIG)

    return model


def params_proj_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {
        k: checkpoint[f"{PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"] for k in model.state_dict().keys()
    }

    return diffusers_checkpoint


# done params_proj


# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
    weights = [None] * split
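The conversion above is a pure key remap: every parameter the diffusers model expects is looked up in the original checkpoint under the `encoder.params_proj` prefix. A minimal sketch of the same pattern on a toy state dict (the key names here are illustrative, not from the Shap-E checkpoint):

```python
import torch

# Toy "original" checkpoint whose keys carry an extra prefix.
original = {
    "encoder.params_proj.proj.weight": torch.zeros(4, 2),
    "encoder.params_proj.proj.bias": torch.zeros(4),
    "encoder.unrelated.weight": torch.zeros(3),  # ignored by the remap
}

PREFIX = "encoder.params_proj"
expected_keys = ["proj.weight", "proj.bias"]  # stand-in for model.state_dict().keys()

# Same dict comprehension as in the script: strip the prefix by lookup.
converted = {k: original[f"{PREFIX}.{k}"] for k in expected_keys}
print(sorted(converted))  # ['proj.bias', 'proj.weight']
```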
@@ -267,6 +294,24 @@ def prior(*, args, checkpoint_map_location):
    return prior_model


def params_proj(*, args, checkpoint_map_location):
    print("loading params_proj")

    params_proj_checkpoint = torch.load(args.params_proj_checkpoint_path, map_location=checkpoint_map_location)

    params_proj_model = params_proj_model_from_original_config()

    params_proj_diffusers_checkpoint = params_proj_original_checkpoint_to_diffusers_checkpoint(
        params_proj_model, params_proj_checkpoint
    )

    del params_proj_checkpoint

    load_checkpoint_to_model(params_proj_diffusers_checkpoint, params_proj_model, strict=True)

    print("done loading params_proj")

    return params_proj_model


def load_checkpoint_to_model(checkpoint, model, strict=False):
    with tempfile.NamedTemporaryFile() as file:
        torch.save(checkpoint, file.name)
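`load_checkpoint_to_model` round-trips the converted state dict through a temporary file because accelerate's `load_checkpoint_and_dispatch` (imported at the top of the script) loads from a path rather than an in-memory dict. The hunk is truncated here, so the following is only a hedged sketch of that pattern; the `strict` branch is an assumption:

```python
import tempfile

import torch
from accelerate import load_checkpoint_and_dispatch


def load_checkpoint_to_model_sketch(checkpoint, model, strict=False):
    # Write the in-memory state dict to disk so path-based loaders can read it.
    with tempfile.NamedTemporaryFile() as file:
        torch.save(checkpoint, file.name)
        if strict:
            # Plain strict loading; no dispatch needed.
            model.load_state_dict(torch.load(file.name), strict=True)
        else:
            # accelerate reads the checkpoint back from disk and places weights.
            load_checkpoint_and_dispatch(model, file.name, device_map="auto")
    return model
```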
@@ -286,7 +331,15 @@ if __name__ == "__main__":
        "--prior_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the prior checkpoint to convert.",
    )

    parser.add_argument(
        "--params_proj_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the params_proj checkpoint to convert.",
    )
@@ -320,5 +373,8 @@ if __name__ == "__main__":
    elif args.debug == "prior":
        prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
        prior_model.save_pretrained(args.dump_path)
    elif args.debug == "params_proj":
        params_proj_model = params_proj(args=args, checkpoint_map_location=checkpoint_map_location)
        params_proj_model.save_pretrained(args.dump_path)
    else:
        raise ValueError(f"unknown debug value: {args.debug}")
@@ -13,3 +13,5 @@ except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
    from .pipeline_shap_e import ShapEPipeline
    from .params_proj import ShapEParamsProjModel
    from .renderer import MLPNeRSTFModel
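With these `__init__.py` additions, the new classes become importable from the pipeline package (assuming torch and transformers are installed, since the imports sit in the `else` branch of the optional-dependency check):

```python
# Hypothetical smoke test for the new public names added by this commit.
from diffusers.pipelines.shap_e import MLPNeRSTFModel, ShapEParamsProjModel, ShapEPipeline
```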
src/diffusers/pipelines/shap_e/camera.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch


@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batched, differentiable, standard pinhole camera.
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float

    def __post_init__(self):
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2

    def resolution(self):
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        pixel_indices = torch.arange(self.height * self.width)
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]
        flat = coords.view(batch_size, -1, 2)

        res = self.resolution().to(flat.device)
        fov = self.fov().to(flat.device)

        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view, assuming the aspect ratio does not change.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
        )


@dataclass
class DifferentiableCameraBatch:
    """
    Annotate a differentiable camera with a multi-dimensional batch shape.
    """

    shape: Tuple[int]
    flat_camera: DifferentiableProjectiveCamera


def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
    origins = []
    xs = []
    ys = []
    zs = []
    for theta in np.linspace(0, 2 * np.pi, num=20):
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        origin = -z * 4
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableCameraBatch(
        shape=(1, len(xs)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )


def get_image_coords(width, height) -> torch.Tensor:
    pixel_indices = torch.arange(height * width)
    # torch throws warnings for pixel_indices // width
    pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc")
    coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)
    return coords
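A quick shape check for the new camera module (hypothetical usage; the import path follows this commit's layout). `create_pan_cameras` builds 20 cameras panning around the origin, and `camera_rays` returns an origin/direction pair per pixel:

```python
import torch

from diffusers.pipelines.shap_e.camera import create_pan_cameras

cameras = create_pan_cameras(size=64, device=torch.device("cpu"))
flat = cameras.flat_camera                     # 20 pinhole cameras, batched

coords = flat.image_coords()                   # (64 * 64, 2) integer pixel coords
coords = coords.unsqueeze(0).repeat(flat.origin.shape[0], 1, 1)  # one copy per camera

rays = flat.camera_rays(coords)
print(rays.shape)                              # torch.Size([20, 4096, 2, 3]): origin + unit direction
```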
src/diffusers/pipelines/shap_e/params_proj.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


class ChannelsProj(nn.Module):
    def __init__(
        self,
        *,
        vectors: int,
        channels: int,
        d_latent: int,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels)
        self.norm = nn.LayerNorm(channels)
        self.d_latent = d_latent
        self.vectors = vectors
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_bvd = x
        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
        h = self.norm(h)

        h = h + b_vc
        return h


class ShapEParamsProjModel(ModelMixin, ConfigMixin):
    """
    Project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP).

    For more details, see the original paper: https://arxiv.org/abs/2305.02463
    """

    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
    ):
        super().__init__()

        # check inputs
        if len(param_names) != len(param_shapes):
            raise ValueError("Must provide same number of `param_names` as `param_shapes`")
        self.projections = nn.ModuleDict({})
        for k, (vectors, channels) in zip(param_names, param_shapes):
            self.projections[_sanitize_name(k)] = ChannelsProj(
                vectors=vectors,
                channels=channels,
                d_latent=d_latent,
            )

    def forward(self, x: torch.Tensor):
        out = {}
        start = 0
        for k, shape in zip(self.config.param_names, self.config.param_shapes):
            vectors, _ = shape
            end = start + vectors
            x_bvd = x[:, start:end]
            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
            start = end
        return out


def _sanitize_name(x: str) -> str:
    return x.replace(".", "__")
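To see what the projection produces, here is a hedged shape check using the defaults above: the four `(256, n)` parameter shapes consume 256 latent vectors each, so the model expects 4 × 256 = 1024 vectors of dimension `d_latent` = 1024:

```python
import torch

from diffusers.pipelines.shap_e import ShapEParamsProjModel

model = ShapEParamsProjModel()            # defaults from this commit
latents = torch.randn(2, 1024, 1024)      # (batch, total vectors, d_latent)

mlp_params = model(latents)               # dict of per-parameter weight tensors
for name, weight in mlp_params.items():
    print(name, tuple(weight.shape))
# nerstf.mlp.0.weight (2, 256, 93)
# nerstf.mlp.1.weight (2, 256, 256) ... and so on
```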
src/diffusers/pipelines/shap_e/renderer.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
    """
    Concatenate x and its positional encodings, following NeRF.

    Reference: https://arxiv.org/pdf/2210.04628.pdf
    """
    if min_deg == max_deg:
        return x
    scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype)
    *shape, dim = x.shape
    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
    assert xb.shape[-1] == dim * (max_deg - min_deg)
    emb = torch.cat([xb, xb + math.pi / 2.0], dim=-1).sin()
    return torch.cat([x, emb], dim=-1)


def encode_position(position):
    return posenc_nerf(position, min_deg=0, max_deg=15)


def encode_direction(position, direction=None):
    if direction is None:
        return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
    else:
        return posenc_nerf(direction, min_deg=0, max_deg=8)


class MLPNeRSTFModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
    ):
        super().__init__()

        # Instantiate the MLP

        # Find out the dimension of encoded position and direction
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(position=dummy).shape[-1]
        d_posenc_dir = encode_direction(position=dummy).shape[-1]

        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos] + mlp_widths
        output_widths = mlp_widths + [n_output]

        if insert_direction_at is not None:
            input_widths[insert_direction_at] += d_posenc_dir

        self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])

        if act_fn == "swish":
            # swish / SiLU: x * sigmoid(x)
            self.activation = lambda x: x * torch.sigmoid(x)
        else:
            raise ValueError(f"Unsupported activation function {act_fn}")

    def forward(self, *, position, direction):
        h = encode_position(position)

        h_preact = h
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.config.insert_direction_at:  # 4 in the config
                h_directionless = h_preact
                h_direction = encode_direction(position, direction=direction)
                h = torch.cat([h, h_direction], dim=-1)

            h = layer(h)
            h_preact = h
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            h_directionless = h_preact
        return h_final, h_directionless
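The positional-encoding widths explain the `(256, 93)` input shape in `params_proj.py`: `encode_position` yields 3 + 3 × 2 × 15 = 93 features, and `encode_direction` yields 3 + 3 × 2 × 8 = 51, which is what `insert_direction_at` splices into the fourth layer. A hedged smoke test (import path per this commit):

```python
import torch

from diffusers.pipelines.shap_e.renderer import MLPNeRSTFModel, encode_direction, encode_position

position = torch.randn(4, 3)
print(encode_position(position).shape)   # torch.Size([4, 93])
print(encode_direction(position).shape)  # torch.Size([4, 51]); zeros when no direction is given

model = MLPNeRSTFModel()
h_final, h_directionless = model(position=position, direction=torch.randn(4, 3))
print(h_final.shape)                     # torch.Size([4, 12]) == n_output
```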