1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add from_pt argument in .from_pretrained (#527)

* first commit:

- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file

* small nit

- fix a small nit - to not enter in the second if condition

* major changes

- modify FlaxUnet modules
- first conversion script
- more keys to be matched

* keys match

- now all keys match
- change module names for correct matching
- upsample module name changed

* working v1

- test pass with atol and rtol= `4e-02`

* replace unsued arg

* make quality

* add small docstring

* add more comments

- add TODO for embedding layers

* small change

- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array

* add more conditions on conversion

- add better test to check for keys conversion

* make shapes consistent

- output `img_w x img_h x n_channels` from the VAE

* Revert "make shapes consistent"

This reverts commit 4cad1aeb4a.

* fix unet shape

- channels first!
This commit is contained in:
Younes Belkada
2022-09-20 12:39:25 +02:00
committed by GitHub
parent ca74951323
commit 0902449ef8
5 changed files with 199 additions and 43 deletions

View File

@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .utils import logging
logger = logging.get_logger(__name__)
def rename_key(key):
regex = r"\w+[.]\d+"
pats = re.findall(regex, key)
for pat in pats:
key = key.replace(pat, "_".join(pat.split(".")))
return key
#####################
# PyTorch => Flax #
#####################
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
):
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
# embedding
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
return renamed_pt_tuple_key, pt_tensor
# conv layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor
# linear layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight":
pt_tensor = pt_tensor.T
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm weight
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
if pt_tuple_key[-1] == "gamma":
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm bias
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
# Step 1: Convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# Step 2: Since the model is stateless, get random Flax params
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
random_flax_state_dict = flatten_dict(random_flax_params)
flax_state_dict = {}
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
renamed_pt_key = rename_key(pt_key)
pt_tuple_key = tuple(renamed_pt_key.split("."))
# Correctly rename weight parameters
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
if flax_key in random_flax_state_dict:
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
return unflatten_dict(flax_state_dict)

View File

@@ -27,7 +27,8 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .modeling_utils import WEIGHTS_NAME
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import WEIGHTS_NAME, load_state_dict
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
@@ -245,6 +246,8 @@ class FlaxModelMixin:
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch checkpoint save file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@ class FlaxModelMixin:
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ class FlaxModelMixin:
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
)
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights."
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
" using `from_pt=True`."
)
else:
raise EnvironmentError(
@@ -320,7 +330,7 @@ class FlaxModelMixin:
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_NAME,
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -370,25 +380,32 @@ class FlaxModelMixin:
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
try:
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
if from_pt:
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)
# Step 2: Convert the weights
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
else:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
# flatten dicts

View File

@@ -32,7 +32,7 @@ class FlaxAttentionBlock(nn.Module):
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module):
def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -93,12 +93,12 @@ class FlaxBasicTransformerBlock(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual
# feed forward
@@ -168,13 +168,27 @@ class FlaxGluFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
hidden_states = self.net_0(hidden_states)
hidden_states = self.net_2(hidden_states)
return hidden_states
class FlaxGEGLU(nn.Module):
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
return hidden_linear * nn.gelu(hidden_gelu)

View File

@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensors
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if not isinstance(timesteps, jnp.ndarray):
timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
timesteps = timesteps.astype(dtype=jnp.float32)
timesteps = jnp.expand_dims(timesteps, 0)
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
# 3. down
@@ -251,6 +258,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample = self.conv_norm_out(sample)
sample = nn.silu(sample)
sample = self.conv_out(sample)
sample = jnp.transpose(sample, (0, 3, 1, 2))
if not return_dict:
return (sample,)

View File

@@ -55,7 +55,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
self.attentions = attentions
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
output_states = ()
@@ -66,7 +66,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
@@ -96,7 +96,7 @@ class FlaxDownBlock2D(nn.Module):
self.resnets = resnets
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, deterministic=True):
output_states = ()
@@ -106,7 +106,7 @@ class FlaxDownBlock2D(nn.Module):
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
@@ -151,7 +151,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
self.attentions = attentions
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
for resnet, attn in zip(self.resnets, self.attentions):
@@ -164,7 +164,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
@@ -196,7 +196,7 @@ class FlaxUpBlock2D(nn.Module):
self.resnets = resnets
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
for resnet in self.resnets:
@@ -208,7 +208,7 @@ class FlaxUpBlock2D(nn.Module):
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states