1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Deprecate Flax support (#12151)

* Start removing Flax-specific code.

* Add deprecation warning.

* Add warning messages.

* More warnings.

* Remove Dockerfiles.

* Remove more.

* Update src/diffusers/models/attention_flax.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Sayak Paul
2025-08-26 09:58:16 +02:00
committed by GitHub
parent 5fcd5f560f
commit 532f41c999
21 changed files with 186 additions and 1848 deletions

View File

@@ -19,6 +19,11 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)

View File

@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4

View File

@@ -16,6 +16,11 @@ import math
import flax.linen as nn
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(

View File

@@ -290,6 +290,10 @@ class FlaxModelMixin(PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)

View File

@@ -15,12 +15,22 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)

View File

@@ -15,10 +15,14 @@
import flax.linen as nn
import jax.numpy as jnp
from ...utils import logging
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
logger = logging.get_logger(__name__)
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(

View File

@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4

View File

@@ -25,10 +25,13 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, logging
from .modeling_flax_utils import FlaxModelMixin
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# z to block_in
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,

View File

@@ -312,6 +312,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)

View File

@@ -22,9 +22,11 @@ import flax
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin
from ..utils import BaseOutput, PushToHubMixin, logging
logger = logging.get_logger(__name__)
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -133,6 +135,10 @@ class FlaxSchedulerMixin(PushToHubMixin):
</Tip>
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,