
Deprecate Flax support (#12151)

* start removing flax stuff.

* add deprecation warning.

* add warning messages.

* more warnings.

* remove dockerfiles.

* remove more.

* Update src/diffusers/models/attention_flax.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Authored by Sayak Paul on 2025-08-26 09:58:16 +02:00, committed by GitHub
parent 5fcd5f560f · commit 532f41c999
21 changed files with 186 additions and 1848 deletions


@@ -1,38 +0,0 @@
name: Run Flax dependency tests

on:
  pull_request:
    branches:
      - main
    paths:
      - "src/diffusers/**.py"
  push:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check_flax_dependencies:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m pip install --upgrade pip uv
          python -m uv pip install -e .
          python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
          python -m uv pip install "flax>=0.4.1"
          python -m uv pip install "jaxlib>=0.1.65"
          python -m uv pip install pytest
      - name: Check for soft dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          pytest tests/others/test_dependencies.py
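
Note: the removed workflow only ran tests/others/test_dependencies.py, which checks that optional backends such as Flax stay soft dependencies of the package. A minimal sketch of that kind of probe (illustrative only; the real helper is diffusers.utils.is_flax_available):

# Illustrative soft-dependency probe, not the exact test file: availability is
# detected without importing the backend, so a missing Flax install only
# disables Flax features instead of breaking `import diffusers`.
import importlib.util

def is_flax_available() -> bool:
    return importlib.util.find_spec("flax") is not None

print(is_flax_available())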


@@ -1,49 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa
RUN apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.10 \
python3-pip \
python3.10-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3.10 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3 -m uv pip install --upgrade --no-cache-dir \
clu \
"jax[cpu]>=0.2.16,!=0.3.2" \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m uv pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
CMD ["/bin/bash"]


@@ -1,51 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa
RUN apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.10 \
python3-pip \
python3.10-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3.10 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3 -m pip install --no-cache-dir \
"jax[tpu]>=0.2.16,!=0.3.2" \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
python3 -m uv pip install --upgrade --no-cache-dir \
clu \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m uv pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
CMD ["/bin/bash"]
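
Note: both removed images pre-install JAX (CPU wheels in the first image, jax[tpu] plus the libtpu releases index in this one). A generic JAX sanity check, not part of the removed Dockerfiles, is commonly used to confirm which backend the container actually sees:

# Generic JAX sanity check (not from the removed Dockerfiles): confirms whether
# the TPU or CPU backend is active inside the container.
import jax

print(jax.default_backend())  # "tpu" in the TPU image, "cpu" in the CPU image
print(jax.devices())          # the accelerator devices JAX can use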


@@ -19,6 +19,11 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
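
Note: the same logger.warning block is added to the setup() of every Flax module in this commit, so each instantiation now logs the notice. Users who keep the Flax code paths and find the repeated messages noisy can lower the library's verbosity; a minimal sketch, assuming the standard diffusers.utils.logging helpers:

# Minimal sketch: hide the repeated deprecation warnings by lowering the
# diffusers log level (pinning diffusers<1.0.0 is the other option named in
# the warning text).
from diffusers.utils import logging

logging.set_verbosity_error()  # messages emitted via logger.warning() are no longer shown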


@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4


@@ -16,6 +16,11 @@ import math
import flax.linen as nn
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(

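Note: FlaxTimesteps wraps get_sinusoidal_embeddings, the standard sin/cos timestep encoding. A minimal jax.numpy sketch of that computation (argument names are illustrative; the flip_sin_to_cos and freq_shift options of the real helper are omitted):

# Minimal sketch of sinusoidal timestep embeddings in jax.numpy; illustrative
# only, without the flip_sin_to_cos and freq_shift handling of the real helper.
import jax.numpy as jnp

def sinusoidal_embeddings(timesteps: jnp.ndarray, dim: int = 32) -> jnp.ndarray:
    half = dim // 2
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)
    args = timesteps[:, None].astype(jnp.float32) * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

print(sinusoidal_embeddings(jnp.array([0, 10, 100])).shape)  # (3, 32)
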

@@ -290,6 +290,10 @@ class FlaxModelMixin(PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
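
Note: FlaxModelMixin.from_pretrained now logs the warning on every load. A call that triggers it, adapted from the removed integration tests:

# Adapted from the removed Flax UNet integration tests: loading still works,
# but from_pretrained now logs the deprecation warning.
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.float32
)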


@@ -15,12 +15,22 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from ..utils import logging
logger = logging.get_logger(__name__)
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)


@@ -15,10 +15,14 @@
import flax.linen as nn
import jax.numpy as jnp
from ...utils import logging
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
logger = logging.get_logger(__name__)
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
attentions = []
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(


@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4


@@ -25,10 +25,13 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, logging
from .modeling_flax_utils import FlaxModelMixin
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
block_out_channels = self.block_out_channels
# z to block_in
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,


@@ -312,6 +312,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
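
Note: FlaxDiffusionPipeline.from_pretrained gets the same warning. The pmap-style call pattern it guards, condensed from the removed pipeline tests (replicate the params, shard the inputs, pass a split PRNG key):

# Condensed from the removed Flax pipeline tests: replicate params and shard
# inputs across devices, then run the jitted pipeline call.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2", variant="bf16", dtype=jnp.bfloat16
)
prompt_ids = pipe.prepare_inputs(["A painting of a squirrel eating a burger"] * jax.device_count())
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
images = pipe(shard(prompt_ids), replicate(params), rng, num_inference_steps=25, jit=True).images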


@@ -22,9 +22,11 @@ import flax
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin
from ..utils import BaseOutput, PushToHubMixin, logging
logger = logging.get_logger(__name__)
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -133,6 +135,10 @@ class FlaxSchedulerMixin(PushToHubMixin):
</Tip>
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,

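Note: FlaxSchedulerMixin.from_pretrained emits the warning as well. The scheduler-swap pattern it covers, adapted from the removed tests:

# Adapted from the removed tests: load a Flax scheduler (now also warns) and
# thread its state into the pipeline's params dict.
from diffusers import FlaxDPMSolverMultistepScheduler

scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="scheduler"
)
# params["scheduler"] = scheduler_state  # as done before calling the pipeline
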

@@ -1,39 +0,0 @@
import unittest
from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
from ..test_modeling_common_flax import FlaxModelTesterMixin
if is_flax_available():
import jax
@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
model_class = FlaxAutoencoderKL
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
prng_key = jax.random.PRNGKey(0)
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
return {"sample": image, "prng_key": prng_key}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict


@@ -1,66 +0,0 @@
import inspect
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
if is_flax_available():
import jax
@require_flax
class FlaxModelTesterMixin:
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)


@@ -1,104 +0,0 @@
import gc
import unittest
from parameterized import parameterized
from diffusers import FlaxUNet2DConditionModel
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return image
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
dtype = jnp.bfloat16 if fp16 else jnp.float32
revision = "bf16" if fp16 else None
model, params = FlaxUNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", dtype=dtype, revision=revision
)
return model, params
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return hidden_states
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample
assert sample.shape == latents.shape
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample
assert sample.shape == latents.shape
output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)


@@ -1,127 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import is_flax_available, load_image
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@slow
@require_flax
class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_canny(self):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
)
params["controlnet"] = controlnet_params
prompts = "bird"
num_samples = jax.device_count()
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
canny_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
images = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=50,
jit=True,
).images
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
def test_pose(self):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
)
params["controlnet"] = controlnet_params
prompts = "Chef in the kitchen"
num_samples = jax.device_count()
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
pose_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
)
processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, jax.device_count())
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
images = pipe(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
num_inference_steps=50,
jit=True,
).images
assert images.shape == (jax.device_count(), 1, 768, 512, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2


@@ -1,108 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import nightly, require_flax
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@nightly
@require_flax
class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_flax(self):
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2",
variant="bf16",
dtype=jnp.bfloat16,
)
prompt = "A painting of a squirrel eating a burger"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = sd_pipe.prepare_inputs(prompt)
params = replicate(params)
prompt_ids = shard(prompt_ids)
prng_seed = jax.random.PRNGKey(0)
prng_seed = jax.random.split(prng_seed, jax.device_count())
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@nightly
@require_flax
class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_dpm_flax(self):
model_id = "stabilityai/stable-diffusion-2"
scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
variant="bf16",
dtype=jnp.bfloat16,
)
params["scheduler"] = scheduler_params
prompt = "A painting of a squirrel eating a burger"
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = sd_pipe.prepare_inputs(prompt)
params = replicate(params)
prompt_ids = shard(prompt_ids)
prng_seed = jax.random.PRNGKey(0)
prng_seed = jax.random.split(prng_seed, jax.device_count())
images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
assert images.shape == (jax.device_count(), 1, 768, 768, 3)
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
assert jnp.abs(output_slice - expected_slice).max() < 1e-2


@@ -1,82 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import FlaxStableDiffusionInpaintPipeline
from diffusers.utils import is_flax_available, load_image
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
@slow
@require_flax
class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
def test_stable_diffusion_inpaint_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-inpaint/init_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
)
model_id = "xvjiarui/stable-diffusion-2-inpainting"
pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
init_image = num_samples * [init_image]
mask_image = num_samples * [mask_image]
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
processed_masked_images = shard(processed_masked_images)
processed_masks = shard(processed_masks)
output = pipeline(
prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
)
images = output.images.reshape(num_samples, 512, 512, 3)
image_slice = images[0, 253:256, 253:256, -1]
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array(
[0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
)
assert jnp.abs(output_slice - expected_slice).max() < 1e-2


@@ -1,260 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
@require_flax
class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = FlaxDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert not any(f.endswith(".bin") for f in files)
@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 4
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == num_samples
def test_stable_diffusion_v1_4(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
scheduler = FlaxDDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
steps_offset=1,
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
scheduler=scheduler,
safety_checker=None,
)
scheduler_state = scheduler.create_state()
params["scheduler"] = scheduler_state
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
def test_jax_memory_efficient_attention(self):
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
slice = images[2, 0, 256, 10:17, 1]
# With memory efficient attention
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
use_memory_efficient_attention=True,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images_eff.shape == (num_samples, 1, 512, 512, 3)
slice_eff = images[2, 0, 256, 10:17, 1]
# I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
# over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
assert abs(slice_eff - slice).max() < 1e-2


@@ -1,920 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
if is_flax_available():
import jax
import jax.numpy as jnp
from jax import random
jax_device = jax.default_backend()
@require_flax
class FlaxSchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()
forward_default_kwargs = ()
@property
def dummy_sample(self):
batch_size = 4
num_channels = 3
height = 8
width = 8
key1, key2 = random.split(random.PRNGKey(0))
sample = random.uniform(key1, (batch_size, num_channels, height, width))
return sample, key2
@property
def dummy_sample_deter(self):
batch_size = 4
num_channels = 3
height = 8
width = 8
num_elems = batch_size * num_channels * height * width
sample = jnp.arange(num_elems)
sample = sample.reshape(num_channels, height, width, batch_size)
sample = sample / num_elems
return jnp.transpose(sample, (3, 0, 1, 2))
def get_scheduler_config(self):
raise NotImplementedError
def dummy_model(self):
def model(sample, t, *args):
return sample * t / (t + 1)
return model
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample, key = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample, key = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample, key = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, key = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
return t.at[t != t].set(0)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, key = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
scheduler_classes = (FlaxDDPMScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"variance_type": "fixed_small",
"clip_sample": True,
}
config.update(**kwargs)
return config
def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_variance_type(self):
for variance in ["fixed_small", "fixed_large", "other"]:
self.check_over_configs(variance_type=variance)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_time_indices(self):
for t in [0, 500, 999]:
self.check_over_forward(time_step=t)
def test_variance(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
num_trained_timesteps = len(scheduler)
model = self.dummy_model()
sample = self.dummy_sample_deter
key1, key2 = random.split(random.PRNGKey(0))
for t in reversed(range(num_trained_timesteps)):
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
output = scheduler.step(state, residual, t, sample, key1)
pred_prev_sample = output.prev_sample
state = output.state
key1, key2 = random.split(key2)
# if t > 0:
# noise = self.dummy_sample_deter
# variance = scheduler.get_variance(t) ** (0.5) * noise
#
# sample = pred_prev_sample + variance
sample = pred_prev_sample
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
assert abs(result_sum - 255.0714) < 1e-2
assert abs(result_mean - 0.332124) < 1e-3
else:
assert abs(result_sum - 270.2) < 1e-1
assert abs(result_mean - 0.3519494) < 1e-3
@require_flax
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
scheduler_classes = (FlaxDDIMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50),)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
}
config.update(**kwargs)
return config
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
key1, key2 = random.split(random.PRNGKey(0))
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
state = scheduler.set_timesteps(state, num_inference_steps)
for t in state.timesteps:
residual = model(sample, t)
output = scheduler.step(state, residual, t, sample)
sample = output.prev_sample
state = output.state
key1, key2 = random.split(key2)
return sample
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
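# round-trip the scheduler config through save_config/from_pretrained and check that the reloaded scheduler produces identical step() outputs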
for scheduler_class in self.scheduler_classes:
sample, _ = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample, _ = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample, _ = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
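# NaN is the only value for which t != t, so this zeroes NaN entries before the allclose comparison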
return t.at[t != t].set(0)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, _ = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, _ = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_steps_offset(self):
for steps_offset in [0, 1]:
self.check_over_configs(steps_offset=steps_offset)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, 5)
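# 1000 train timesteps / 5 inference steps = stride 200, so the base schedule is [800, 600, 400, 200, 0]; steps_offset=1 shifts it to [801, 601, 401, 201, 1]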
assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [1, 10, 49]:
self.check_over_forward(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_variance(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
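# DDIM variance between t and prev_t is (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev);
# the hard-coded values below presumably follow from that formula under the default linear beta schedule.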
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
assert abs(result_sum - 172.0067) < 1e-2
assert abs(result_mean - 0.223967) < 1e-3
def test_full_loop_with_set_alpha_to_one(self):
# We specify a different beta_start so that the first alpha is 0.99 (alpha_0 = 1 - beta_0 = 1 - 0.01)
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
assert abs(result_sum - 149.8409) < 1e-2
assert abs(result_mean - 0.1951) < 1e-3
else:
assert abs(result_sum - 149.8295) < 1e-2
assert abs(result_mean - 0.1951) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self):
# We specify a different beta_start so that the first alpha is 0.99 (alpha_0 = 1 - beta_0 = 1 - 0.01)
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
pass
# FIXME: both result_sum and result_mean are nan on TPU
# assert jnp.isnan(result_sum)
# assert jnp.isnan(result_mean)
else:
assert abs(result_sum - 149.0784) < 1e-2
assert abs(result_mean - 0.1941) < 1e-3
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
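# PNDM scheduler tests: PRK warmup plus PLMS multistep loop, the `ets` residual history, and steps_offset handling.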
@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
scheduler_classes = (FlaxPNDMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50),)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
}
config.update(**kwargs)
return config
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample, _ = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
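# PLMS is a linear multistep method: it keeps the most recent noise predictions in state.ets, so dummy residuals are injected to exercise that path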
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals
state = state.replace(ets=dummy_past_residuals[:])
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals
new_state = new_state.replace(ets=dummy_past_residuals[:])
(prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
(new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"
output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
return t.at[t != t].set(0)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, _ = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample, _ = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals (must be done after setting timesteps)
state = state.replace(ets=dummy_past_residuals[:])
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals (must be done after setting timesteps)
new_state = new_state.replace(ets=dummy_past_residuals[:])
output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
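# PNDM first runs Runge-Kutta (PRK) warmup steps to build up the residual history, then switches to linear multistep (PLMS) for the remaining timesteps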
for i, t in enumerate(state.prk_timesteps):
residual = model(sample, t)
sample, state = scheduler.step_prk(state, residual, t, sample)
for i, t in enumerate(state.plms_timesteps):
residual = model(sample, t)
sample, state = scheduler.step_plms(state, residual, t, sample)
return sample
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
sample, _ = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
state = state.replace(ets=dummy_past_residuals[:])
output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_steps_offset(self):
for steps_offset in [0, 1]:
self.check_over_configs(steps_offset=steps_offset)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, 10, shape=())
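# with 10 inference steps the base stride is 100; the repeated leading entries appear to come from the PRK warmup, which steps at half the stride, and steps_offset=1 shifts every timestep by one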
assert jnp.equal(
state.timesteps,
jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
).all()
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [1, 5, 10]:
self.check_over_forward(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_pow_of_3_inference_steps(self):
# an earlier version of set_timesteps() raised an error when indexing alphas with a number of inference steps that is a power of 3
num_inference_steps = 27
for scheduler_class in self.scheduler_classes:
sample, _ = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
# before the power-of-3 fix this errored on the first step, so two steps are enough to cover the regression
for i, t in enumerate(state.prk_timesteps[:2]):
sample, state = scheduler.step_prk(state, residual, t, sample)
def test_inference_plms_no_past_residuals(self):
with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
state = scheduler.create_state()
scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
assert abs(result_sum - 198.1275) < 1e-2
assert abs(result_mean - 0.2580) < 1e-3
else:
assert abs(result_sum - 198.1318) < 1e-2
assert abs(result_mean - 0.2580) < 1e-3
def test_full_loop_with_set_alpha_to_one(self):
# We specify a different beta_start so that the first alpha is 0.99 (alpha_0 = 1 - beta_0 = 1 - 0.01)
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
assert abs(result_sum - 186.83226) < 1e-2
assert abs(result_mean - 0.24327) < 1e-3
else:
assert abs(result_sum - 186.9466) < 1e-2
assert abs(result_mean - 0.24342) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self):
# We specify a different beta_start so that the first alpha is 0.99 (alpha_0 = 1 - beta_0 = 1 - 0.01)
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = jnp.sum(jnp.abs(sample))
result_mean = jnp.mean(jnp.abs(sample))
if jax_device == "tpu":
assert abs(result_sum - 186.83226) < 1e-2
assert abs(result_mean - 0.24327) < 1e-3
else:
assert abs(result_sum - 186.9482) < 1e-2
assert abs(result_mean - 0.2434) < 1e-3