This commit is contained in:
yiyixuxu
2023-06-18 19:45:29 +00:00
parent 20e5be74d8
commit 9324a54c67
5 changed files with 402 additions and 3 deletions

scripts/convert_shap_e_to_diffusers.py

@@ -3,8 +3,10 @@ import tempfile
import torch
from accelerate import load_checkpoint_and_dispatch
from collections import OrderedDict
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapEParamsProjModel
"""
@@ -19,8 +21,9 @@ Convert the model:
```sh
$ python scripts/convert_shap_e_to_diffusers.py \
--prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
--dump_path /home/yiyi_huggingface_co/model_repo/shape \
--debug prior
--params_proj_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt \
--dump_path /home/yiyi_huggingface_co/model_repo/shape/params_proj \
--debug params_proj
```
"""
@@ -216,6 +219,30 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
# done prior
# params_proj
PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj"
PARAMS_PROJ_CONFIG = {}
def params_proj_model_from_original_config():
model = ShapEParamsProjModel(**PARAMS_PROJ_CONFIG)
return model
def params_proj_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {
k: checkpoint[f"{PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"] for k in model.state_dict().keys()
}
return diffusers_checkpoint
# done params_proj
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
weights = [None] * split
@@ -267,6 +294,24 @@ def prior(*, args, checkpoint_map_location):
return prior_model
def params_proj(*, args, checkpoint_map_location):
print("loading params_proj")
params_proj_checkpoint = torch.load(args.params_proj_checkpoint_path, map_location=checkpoint_map_location)
params_proj_model = params_proj_model_from_original_config()
params_proj_diffusers_checkpoint = params_proj_original_checkpoint_to_diffusers_checkpoint(params_proj_model, params_proj_checkpoint)
del params_proj_checkpoint
load_checkpoint_to_model(params_proj_diffusers_checkpoint, params_proj_model, strict=True)
print("done loading params_proj")
return params_proj_model
def load_checkpoint_to_model(checkpoint, model, strict=False):
with tempfile.NamedTemporaryFile() as file:
torch.save(checkpoint, file.name)
@@ -286,7 +331,15 @@ if __name__ == "__main__":
"--prior_checkpoint_path",
default=None,
type=str,
required=True,
required=False,
help="Path to the prior checkpoint to convert.",
)
parser.add_argument(
"--params_proj_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to the prior checkpoint to convert.",
)
@@ -320,5 +373,8 @@ if __name__ == "__main__":
elif args.debug == "prior":
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
elif args.debug == "params_proj":
params_proj_model = params_proj(args=args, checkpoint_map_location=checkpoint_map_location)
params_proj_model.save_pretrained(args.dump_path)
else:
raise ValueError(f"unknown debug value : {args.debug}")

src/diffusers/pipelines/shap_e/__init__.py

@@ -13,3 +13,5 @@ except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .pipeline_shap_e import ShapEPipeline
from .params_proj import ShapEParamsProjModel
from .renderer import MLPNeRSTFModel
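Only the tail of the guarded import is visible in this hunk. For context, a sketch of how the full `__init__.py` reads after the change, with the `try` block reconstructed from the usual diffusers optional-dependency pattern (not part of this diff):

```python
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available

try:
    # Fall back to dummy objects when torch or transformers is missing.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
    from .pipeline_shap_e import ShapEPipeline
    from .params_proj import ShapEParamsProjModel
    from .renderer import MLPNeRSTFModel
```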

src/diffusers/pipelines/shap_e/camera.py

@@ -0,0 +1,142 @@
from dataclasses import dataclass
from typing import Tuple
import torch
import numpy as np
@dataclass
class DifferentiableProjectiveCamera:
"""
Implements a batch, differentiable, standard pinhole camera
"""
origin: torch.Tensor # [batch_size x 3]
x: torch.Tensor # [batch_size x 3]
y: torch.Tensor # [batch_size x 3]
z: torch.Tensor # [batch_size x 3]
width: int
height: int
x_fov: float
y_fov: float
def __post_init__(self):
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
assert (
len(self.x.shape)
== len(self.y.shape)
== len(self.z.shape)
== len(self.origin.shape)
== 2
)
def resolution(self):
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
def fov(self):
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
def image_coords(self) -> torch.Tensor:
"""
:return: coords of shape (width * height, 2)
"""
pixel_indices = torch.arange(self.height * self.width)
coords = torch.stack(
[
pixel_indices % self.width,
torch.div(pixel_indices, self.width, rounding_mode="trunc"),
],
axis=1,
)
return coords
def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
batch_size, *shape, n_coords = coords.shape
assert n_coords == 2
assert batch_size == self.origin.shape[0]
flat = coords.view(batch_size, -1, 2)
res = self.resolution().to(flat.device)
fov = self.fov().to(flat.device)
fracs = (flat.float() / (res - 1)) * 2 - 1
fracs = fracs * torch.tan(fov / 2)
fracs = fracs.view(batch_size, -1, 2)
directions = (
self.z.view(batch_size, 1, 3)
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
)
directions = directions / directions.norm(dim=-1, keepdim=True)
rays = torch.stack(
[
torch.broadcast_to(
self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]
),
directions,
],
dim=2,
)
return rays.view(batch_size, *shape, 2, 3)
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
"""
Creates a new camera for the resized view assuming the aspect ratio does not change.
"""
assert width * self.height == height * self.width, "The aspect ratio should not change."
return DifferentiableProjectiveCamera(
origin=self.origin,
x=self.x,
y=self.y,
z=self.z,
width=width,
height=height,
x_fov=self.x_fov,
y_fov=self.y_fov,
)
@dataclass
class DifferentiableCameraBatch:
"""
Annotate a differentiable camera with a multi-dimensional batch shape.
"""
shape: Tuple[int]
flat_camera: DifferentiableProjectiveCamera
def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
origins = []
xs = []
ys = []
zs = []
for theta in np.linspace(0, 2 * np.pi, num=20):
z = np.array([np.sin(theta), np.cos(theta), -0.5])
z /= np.sqrt(np.sum(z**2))
origin = -z * 4
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
y = np.cross(z, x)
origins.append(origin)
xs.append(x)
ys.append(y)
zs.append(z)
return DifferentiableCameraBatch(
shape=(1, len(xs)),
flat_camera=DifferentiableProjectiveCamera(
origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
width=size,
height=size,
x_fov=0.7,
y_fov=0.7,
),
)
def get_image_coords(width, height) -> torch.Tensor:
pixel_indices = torch.arange(height * width)
# torch throws warnings for pixel_indices // width
pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc")
coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)
return coords
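As a quick sanity check of the camera math, the sketch below builds the 20 pan cameras and casts one ray per pixel (assuming the camera definitions above are in scope; they are not re-exported in this commit's `__init__.py`). The shapes follow from `image_coords` and `camera_rays`:

```python
import torch

# Twenty cameras panning around the origin, rendering at 64x64.
cameras = create_pan_cameras(size=64, device=torch.device("cpu"))
cam = cameras.flat_camera

# One (x, y) pixel coordinate per pixel, repeated for each camera in the batch.
coords = cam.image_coords()                                     # (64 * 64, 2)
coords = coords.unsqueeze(0).repeat(cam.origin.shape[0], 1, 1)  # (20, 4096, 2)

rays = cam.camera_rays(coords)                                  # (20, 4096, 2, 3)
origins, directions = rays[..., 0, :], rays[..., 1, :]
# Ray directions are unit vectors.
assert torch.allclose(directions.norm(dim=-1), torch.ones_like(directions[..., 0]))
```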

src/diffusers/pipelines/shap_e/params_proj.py

@@ -0,0 +1,96 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from typing import Tuple, Optional
from collections import OrderedDict
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
class ChannelsProj(nn.Module):
def __init__(
self,
*,
vectors: int,
channels: int,
d_latent: int,
):
super().__init__()
self.proj = nn.Linear(d_latent, vectors * channels)
self.norm = nn.LayerNorm(channels)
self.d_latent = d_latent
self.vectors = vectors
self.channels = channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_bvd = x
w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
b_vc = self.proj.bias.view(1, self.vectors, self.channels)
h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
h = self.norm(h)
h = h + b_vc
return h
class ShapEParamsProjModel(ModelMixin, ConfigMixin):
"""
Projects the latent representation of a 3D asset to obtain the weights of a multi-layer perceptron (MLP).
For more details, see the original Shap-E paper: https://arxiv.org/abs/2305.02463
"""
@register_to_config
def __init__(
self,
*,
param_names: Tuple[str] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = ((256, 93), (256, 256), (256, 256), (256, 256)),
d_latent: int = 1024,
):
super().__init__()
# check inputs
if len(param_names) != len(param_shapes):
raise ValueError(
f"Must provide same number of `param_names` as `param_shapes`"
)
self.projections = nn.ModuleDict({})
for k, (vectors, channels) in zip(param_names, param_shapes):
self.projections[_sanitize_name(k)] = ChannelsProj(
vectors=vectors,
channels=channels,
d_latent=d_latent,
)
def forward(self, x: torch.Tensor):
out = {}
start = 0
for k, shape in zip(self.config.param_names, self.config.param_shapes):
vectors, _ = shape
end = start + vectors
x_bvd = x[:, start:end]
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
start = end
return out
def _sanitize_name(x: str) -> str:
return x.replace(".", "__")
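A minimal usage sketch with the default config: the latent carries 4 x 256 = 1024 vectors of dimension `d_latent`, and each named slice is projected to the matching MLP weight shape:

```python
import torch

from diffusers.pipelines.shap_e import ShapEParamsProjModel

model = ShapEParamsProjModel()  # defaults: four (256, ...) weight tensors, d_latent=1024

total_vectors = sum(v for v, _ in model.config.param_shapes)  # 4 * 256 = 1024
latent = torch.randn(2, total_vectors, model.config.d_latent)

params = model(latent)
for name, tensor in params.items():
    print(name, tuple(tensor.shape))
# nerstf.mlp.0.weight (2, 256, 93)
# nerstf.mlp.1.weight (2, 256, 256)
# ...
```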

src/diffusers/pipelines/shap_e/renderer.py

@@ -0,0 +1,103 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import math
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
"""
Concatenate x and its positional encodings, following NeRF.
Reference: https://arxiv.org/pdf/2210.04628.pdf
"""
if min_deg == max_deg:
return x
scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype)
*shape, dim = x.shape
xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
assert xb.shape[-1] == dim * (max_deg - min_deg)
emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
return torch.cat([x, emb], dim=-1)
def encode_position(position):
return posenc_nerf(position, min_deg=0, max_deg=15)
def encode_direction(position, direction=None):
if direction is None:
return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
else:
return posenc_nerf(direction, min_deg=0, max_deg=8)
class MLPNeRSTFModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
d_hidden: int = 256,
n_output: int = 12,
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
):
super().__init__()
# Instantiate the MLP
# Find out the dimension of encoded position and direction
dummy = torch.eye(1, 3)
d_posenc_pos = encode_position(position=dummy).shape[-1]
d_posenc_dir = encode_direction(position=dummy).shape[-1]
mlp_widths = [d_hidden] * n_hidden_layers
input_widths = [d_posenc_pos] + mlp_widths
output_widths = mlp_widths + [n_output]
if insert_direction_at is not None:
input_widths[insert_direction_at] += d_posenc_dir
self.mlp = nn.ModuleList(
[
nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)
]
)
if act_fn == "swish":
self.activation = lambda x: x * torch.sigmoid(x)  # swish / SiLU: x * sigmoid(x)
else:
raise ValueError(f"Unsupported activation function {act_fn}")
def forward(self, *, position, direction):
h = encode_position(position)
h_preact = h
h_directionless = None
for i, layer in enumerate(self.mlp):
if i == self.config.insert_direction_at: # 4 in the config
h_directionless = h_preact
h_direction = encode_direction(position, direction=direction)
h = torch.cat([h, h_direction], dim=-1)
h = layer(h)
h_preact = h
if i < len(self.mlp) - 1:
h = self.activation(h)
h_final = h
if h_directionless is None:
h_directionless = h_preact
return h_final, h_directionless
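A minimal forward-pass sketch with the defaults; the 93- and 51-dimensional position/direction encodings and the 12 output channels all follow from the constructor arguments above:

```python
import torch

from diffusers.pipelines.shap_e import MLPNeRSTFModel

mlp = MLPNeRSTFModel()  # 6 hidden layers of width 256, 12 output channels

# A batch of 2 rays with 8 query points each.
position = torch.rand(2, 8, 3) * 2 - 1                   # points in [-1, 1]^3
direction = torch.randn(2, 8, 3)
direction = direction / direction.norm(dim=-1, keepdim=True)

h_final, h_directionless = mlp(position=position, direction=direction)
print(h_final.shape)          # torch.Size([2, 8, 12])
print(h_directionless.shape)  # torch.Size([2, 8, 256]), hidden state just before the direction encoding is injected
```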