Deprecate Flax support (#12151)
* start removing flax stuff.
* add deprecation warning.
* add warning messages.
* more warnings.
* remove dockerfiles.
* remove more.
* Update src/diffusers/models/attention_flax.py

  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
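Every hunk below applies the same pattern: import the library's logger at module level, then emit a deprecation warning when a Flax module is set up. A minimal sketch of that pattern, assuming only standard Flax and the diffusers logging helpers (FlaxExampleModule is a hypothetical name for illustration, not part of the diff):

import flax.linen as nn
import jax.numpy as jnp

from diffusers.utils import logging

logger = logging.get_logger(__name__)


class FlaxExampleModule(nn.Module):  # hypothetical module, for illustration only
    features: int = 4
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # setup() runs when Flax binds the module during init/apply, so the
        # warning is emitted at use time rather than at import time.
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )
        self.dense = nn.Dense(self.features, dtype=self.dtype)

    def __call__(self, x):
        return self.dense(x)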
src/diffusers/models/attention_flax.py

@@ -19,6 +19,11 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
     """Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim_head * self.heads
         self.scale = self.dim_head**-0.5
 
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
             self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.dropout)

src/diffusers/models/controlnets/controlnet_flax.py

@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
 from ..modeling_flax_utils import FlaxModelMixin
 from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxControlNetOutput(BaseOutput):
     """
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv_in = nn.Conv(
             self.block_out_channels[0],
             kernel_size=(3, 3),
@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4

src/diffusers/models/embeddings_flax.py

@@ -16,6 +16,11 @@ import math
 import flax.linen as nn
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def get_sinusoidal_embeddings(
     timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
             The data type for the embedding parameters.
     """
 
+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     time_embed_dim: int = 32
     dtype: jnp.dtype = jnp.float32
 
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
     flip_sin_to_cos: bool = False
     freq_shift: float = 1
 
+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     @nn.compact
     def __call__(self, timesteps):
         return get_sinusoidal_embeddings(

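Unlike the other modules, FlaxTimesteps (and, by the look of these hunks, FlaxTimestepEmbedding) appears to be a compact-style module whose forward pass is an @nn.compact __call__ with no setup(), so the warning sits directly in the class body. A statement placed there executes once, when the class is defined at import time; a tiny illustration of that plain-Python behavior, using a hypothetical toy class that is not part of the diff:

import flax.linen as nn


class CompactToy(nn.Module):  # hypothetical toy, for illustration only
    # Class-body statements run at class-definition time, i.e. on import;
    # in the hunks above, this slot holds the deprecation warning.
    print("CompactToy defined; a logger.warning here would fire on import")

    @nn.compact
    def __call__(self, x):
        # Compact modules declare sublayers inline instead of in setup().
        return nn.Dense(4)(x)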
src/diffusers/models/modeling_flax_utils.py

@@ -290,6 +290,10 @@ class FlaxModelMixin(PushToHubMixin):
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)

src/diffusers/models/resnet_flax.py

@@ -15,12 +15,22 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 class FlaxUpsample2D(nn.Module):
     out_channels: int
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
         self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)

src/diffusers/models/unets/unet_2d_blocks_flax.py

@@ -15,10 +15,14 @@
 import flax.linen as nn
 import jax.numpy as jnp
 
+from ...utils import logging
 from ..attention_flax import FlaxTransformer2DModel
 from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
 
 
+logger = logging.get_logger(__name__)
+
+
 class FlaxCrossAttnDownBlock2D(nn.Module):
     r"""
     Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         attentions = []
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
 
         for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         attentions = []
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
 
         for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(

src/diffusers/models/unets/unet_2d_condition_flax.py

@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
 from ..modeling_flax_utils import FlaxModelMixin
 from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxUNet2DConditionOutput(BaseOutput):
     """
@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4

src/diffusers/models/vae_flax.py

@@ -25,10 +25,13 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .modeling_flax_utils import FlaxModelMixin
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxDecoderOutput(BaseOutput):
     """
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.conv = nn.Conv(
             self.in_channels,
             kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.in_channels,
             kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
         self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
 
         dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         # in
         self.conv_in = nn.Conv(
@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
 
         # z to block_in
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.encoder = FlaxEncoder(
             in_channels=self.config.in_channels,
             out_channels=self.config.latent_channels,

src/diffusers/pipelines/pipeline_flax_utils.py

@@ -312,6 +312,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
         >>> dpm_params["scheduler"] = dpmpp_state
         ```
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         cache_dir = kwargs.pop("cache_dir", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)

src/diffusers/schedulers/scheduling_utils_flax.py

@@ -22,9 +22,11 @@ import flax
 import jax.numpy as jnp
 from huggingface_hub.utils import validate_hf_hub_args
 
-from ..utils import BaseOutput, PushToHubMixin
+from ..utils import BaseOutput, PushToHubMixin, logging
 
 
+logger = logging.get_logger(__name__)
+
+
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -133,6 +135,10 @@ class FlaxSchedulerMixin(PushToHubMixin):
         </Tip>

         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         config, kwargs = cls.load_config(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             subfolder=subfolder,

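For downstream code that must stay on Flax for now, the warning text itself names the two options: pin Diffusers to a pre-1.0 release (e.g. pip install "diffusers<1.0.0"), or migrate to the PyTorch classes. Because the notice goes through Diffusers' own logger, it can also be silenced during migration; a minimal sketch, assuming the standard diffusers.utils.logging helpers:

from diffusers.utils import logging

# Only show ERROR-level messages from diffusers' loggers, which hides
# WARNING-level output such as the Flax deprecation notice.
logging.set_verbosity_error()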