Mirror of https://github.com/huggingface/diffusers.git
Add from_pt argument in .from_pretrained (#527)
* first commit:
- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
* small nit
- fix a small nit so that the second if condition is not entered
* major changes
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
* keys match
- now all keys match
- change module names for correct matching
- upsample module name changed
* working v1
- tests pass with atol and rtol = `4e-02`
* replace unused arg
* make quality
* add small docstring
* add more comments
- add TODO for embedding layers
* small change
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
- add better test to check for keys conversion
* make shapes consistent
- output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
This reverts commit 4cad1aeb4a.
* fix unet shape
- channels first!
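
For context, a minimal usage sketch of the new flag (not part of the commit). The checkpoint path is a placeholder, and the `(model, params)` return convention and the top-level `FlaxUNet2DConditionModel` export are assumed from the surrounding Flax API rather than stated in this diff.

# Hypothetical usage of the `from_pt` flag introduced by this commit.
from diffusers import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "path/to/pytorch-unet",  # placeholder: any directory/repo holding diffusion_pytorch_model.bin + config
    from_pt=True,            # load the PyTorch state dict and convert it to Flax params
)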
src/diffusers/modeling_flax_pytorch_utils.py (new file, +117 lines)
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .utils import logging


logger = logging.get_logger(__name__)


def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


#####################
# PyTorch => Flax #
#####################

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
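
A standalone sketch (not part of the commit) of the tensor reshaping that `rename_key_and_reshape_tensor` performs; `numpy` stands in for real PyTorch tensors and the shapes below are illustrative.

import numpy as np

# PyTorch Conv2d weights are laid out as (out_channels, in_channels, kH, kW); Flax nn.Conv
# expects (kH, kW, in_channels, out_channels) -- hence the transpose(2, 3, 1, 0) above.
pt_conv_weight = np.zeros((320, 4, 3, 3))
flax_conv_kernel = pt_conv_weight.transpose(2, 3, 1, 0)
print(flax_conv_kernel.shape)  # (3, 3, 4, 320)

# PyTorch nn.Linear weights are (out_features, in_features); Flax nn.Dense kernels are
# (in_features, out_features) -- hence the plain transpose for 2D weights.
pt_linear_weight = np.zeros((1280, 320))
print(pt_linear_weight.T.shape)  # (320, 1280)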
@@ -27,7 +27,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

-from .modeling_utils import WEIGHTS_NAME
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .modeling_utils import WEIGHTS_NAME, load_state_dict
 from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging


@@ -245,6 +246,8 @@ class FlaxModelMixin:
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                 identifier allowed by git.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@ class FlaxModelMixin:
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
+        from_pt = kwargs.pop("from_pt", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ class FlaxModelMixin:
                 # Load from a Flax checkpoint
                 model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
             # At this stage we don't have a weight file so we will raise an error.
+            elif from_pt:
+                if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
+                    raise EnvironmentError(
+                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                    )
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
             elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                 raise EnvironmentError(
-                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                    "but there is a file for PyTorch weights."
+                    f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
+                    " using `from_pt=True`."
                 )
             else:
                 raise EnvironmentError(
@@ -320,7 +330,7 @@ class FlaxModelMixin:
             try:
                 model_file = hf_hub_download(
                     pretrained_model_name_or_path,
-                    filename=FLAX_WEIGHTS_NAME,
+                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     proxies=proxies,
@@ -370,25 +380,32 @@ class FlaxModelMixin:
                 f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
             )

+        try:
+            with open(model_file, "rb") as state_f:
+                state = from_bytes(cls, state_f.read())
+        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+            if from_pt:
+                # Step 1: Get the pytorch file
+                pytorch_model_file = load_state_dict(model_file)
+
+                # Step 2: Convert the weights
+                state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+            else:
+                try:
+                    with open(model_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please"
+                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                                " folder you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+        # make sure all arrays are stored as jnp.ndarray
+        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
+        # https://github.com/google/flax/issues/1261
-            with open(model_file, "rb") as state_f:
-                state = from_bytes(cls, state_f.read())
-            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
-                try:
-                    with open(model_file) as f:
-                        if f.read().startswith("version"):
-                            raise OSError(
-                                "You seem to have cloned a repository without having git-lfs installed. Please"
-                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
-                                " folder you cloned."
-                            )
-                        else:
-                            raise ValueError from e
-                except (UnicodeDecodeError, ValueError):
-                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
-            # make sure all arrays are stored as jnp.ndarray
-            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
-            # https://github.com/google/flax/issues/1261
         state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

         # flatten dicts
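
A toy sketch (not part of the commit) of why the loader above cannot simply feed a PyTorch state dict into a Flax model: Flax names a Dense layer's matrix `kernel` and stores it as (in, out), whereas PyTorch names `nn.Linear`'s matrix `weight` and stores it as (out, in).

import jax
import jax.numpy as jnp
import flax.linen as nn

dense = nn.Dense(features=2)
variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

print(variables["params"]["kernel"].shape)  # (3, 2); the equivalent nn.Linear(3, 2).weight is (2, 3)
print("weight" in variables["params"])      # False -- hence the "weight" -> "kernel"/"scale" renames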
@@ -32,7 +32,7 @@ class FlaxAttentionBlock(nn.Module):
         self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

-        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
+        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module):

     def setup(self):
         # self attention
-        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -93,12 +93,12 @@ class FlaxBasicTransformerBlock(nn.Module):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
+        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

         # cross attention
         residual = hidden_states
-        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
+        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
         hidden_states = hidden_states + residual

         # feed forward
@@ -168,13 +168,27 @@ class FlaxGluFeedForward(nn.Module):
     dtype: jnp.dtype = jnp.float32

     def setup(self):
-        inner_dim = self.dim * 4
-        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
-        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
+        # The second linear layer needs to be called
+        # net_2 for now to match the index of the Sequential layer
+        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.dense1(hidden_states)
-        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
-        hidden_states = self.dense2(hidden_states)
+        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_2(hidden_states)
         return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        inner_dim = self.dim * 4
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = self.proj(hidden_states)
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+        return hidden_linear * nn.gelu(hidden_gelu)
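
The `net_0`/`net_2` names above (and likewise `to_out_0`, `downsamplers_0`, `upsamplers_0` below) follow from how `rename_key` in `modeling_flax_pytorch_utils.py` flattens PyTorch's indexed submodules. A small sketch re-using that regex on typical PyTorch key names:

import re

def rename_key(key):  # same logic as rename_key in modeling_flax_pytorch_utils.py above
    for pat in re.findall(r"\w+[.]\d+", key):
        key = key.replace(pat, "_".join(pat.split(".")))
    return key

for pt_key in ["net.0.proj.weight", "net.2.weight", "to_out.0.weight", "downsamplers.0.conv.weight"]:
    print(pt_key, "->", rename_key(pt_key))
# net.0.proj.weight -> net_0.proj.weight            (FlaxGEGLU, i.e. net_0)
# net.2.weight -> net_2.weight                      (output Dense, i.e. net_2)
# to_out.0.weight -> to_out_0.weight
# downsamplers.0.conv.weight -> downsamplers_0.conv.weight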
@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):

     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors
-        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
         timesteps = jnp.ones((1,), dtype=jnp.int32)
         encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
         t_emb = self.time_proj(timesteps)
         t_emb = self.time_embedding(t_emb)

         # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
         sample = self.conv_in(sample)

         # 3. down
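
A quick sketch (illustrative only) of the 0-dimensional timestep handling added above, matching the `jnp.expand_dims` item in the commit message:

import jax.numpy as jnp

timesteps = jnp.array(7, dtype=jnp.int32)  # 0-dimensional timestep
if not isinstance(timesteps, jnp.ndarray):
    timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif len(timesteps.shape) == 0:
    timesteps = timesteps.astype(dtype=jnp.float32)
    timesteps = jnp.expand_dims(timesteps, 0)
print(timesteps.shape)  # (1,)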
@@ -251,6 +258,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         sample = self.conv_norm_out(sample)
         sample = nn.silu(sample)
         sample = self.conv_out(sample)
+        sample = jnp.transpose(sample, (0, 3, 1, 2))

         if not return_dict:
             return (sample,)
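
Shape-flow sketch (not part of the commit) for the channels-first fix: the Flax UNet now takes NCHW input like the PyTorch model, moves to NHWC internally for Flax convolutions, and transposes back before returning.

import jax.numpy as jnp

sample = jnp.zeros((1, 4, 64, 64))          # (batch, channels, height, width), matching init_weights
nhwc = jnp.transpose(sample, (0, 2, 3, 1))  # (1, 64, 64, 4) before conv_in
nchw = jnp.transpose(nhwc, (0, 3, 1, 2))    # (1, 4, 64, 64) after conv_out
print(sample.shape, nhwc.shape, nchw.shape)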
@@ -55,7 +55,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         self.attentions = attentions

         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
         output_states = ()
@@ -66,7 +66,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             output_states += (hidden_states,)

         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)

         return hidden_states, output_states
@@ -96,7 +96,7 @@ class FlaxDownBlock2D(nn.Module):
         self.resnets = resnets

         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, temb, deterministic=True):
         output_states = ()
@@ -106,7 +106,7 @@ class FlaxDownBlock2D(nn.Module):
             output_states += (hidden_states,)

         if self.add_downsample:
-            hidden_states = self.downsample(hidden_states)
+            hidden_states = self.downsamplers_0(hidden_states)
             output_states += (hidden_states,)

         return hidden_states, output_states
@@ -151,7 +151,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         self.attentions = attentions

         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -164,7 +164,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)

         return hidden_states
@@ -196,7 +196,7 @@ class FlaxUpBlock2D(nn.Module):
         self.resnets = resnets

         if self.add_upsample:
-            self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

     def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
         for resnet in self.resnets:
@@ -208,7 +208,7 @@ class FlaxUpBlock2D(nn.Module):
             hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

         if self.add_upsample:
-            hidden_states = self.upsample(hidden_states)
+            hidden_states = self.upsamplers_0(hidden_states)

         return hidden_states