Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Deprecate Flax support (#12151)
* start removing flax stuff.
* add deprecation warning.
* add warning messages.
* more warnings.
* remove dockerfiles.
* remove more.
* Update src/diffusers/models/attention_flax.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
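For downstream code nothing breaks yet: every Flax class keeps working but logs a deprecation warning on use. A hedged sketch of the two options that warning points to (the model ID below is illustrative, not part of this commit):

    # Option 1: migrate to the PyTorch class; StableDiffusionPipeline is the
    # torch counterpart of FlaxStableDiffusionPipeline.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # Option 2: keep Flax and pin the package until ready to migrate,
    # since removal is scheduled for v1.0.0:
    #   pip install "diffusers<1.0"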
.github/workflows/pr_flax_dependency_test.yml (vendored, 38 lines deleted)
@@ -1,38 +0,0 @@
-name: Run Flax dependency tests
-
-on:
-  pull_request:
-    branches:
-      - main
-    paths:
-      - "src/diffusers/**.py"
-  push:
-    branches:
-      - main
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
-  cancel-in-progress: true
-
-jobs:
-  check_flax_dependencies:
-    runs-on: ubuntu-22.04
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: "3.8"
-      - name: Install dependencies
-        run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m pip install --upgrade pip uv
-          python -m uv pip install -e .
-          python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
-          python -m uv pip install "flax>=0.4.1"
-          python -m uv pip install "jaxlib>=0.1.65"
-          python -m uv pip install pytest
-      - name: Check for soft dependencies
-        run: |
-          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          pytest tests/others/test_dependencies.py
@@ -1,49 +0,0 @@
-FROM ubuntu:20.04
-LABEL maintainer="Hugging Face"
-LABEL repository="diffusers"
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-RUN apt-get -y update \
-    && apt-get install -y software-properties-common \
-    && add-apt-repository ppa:deadsnakes/ppa
-
-RUN apt install -y bash \
-    build-essential \
-    git \
-    git-lfs \
-    curl \
-    ca-certificates \
-    libsndfile1-dev \
-    libgl1 \
-    python3.10 \
-    python3-pip \
-    python3.10-venv && \
-    rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-
-# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
-RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
-    python3 -m uv pip install --upgrade --no-cache-dir \
-    clu \
-    "jax[cpu]>=0.2.16,!=0.3.2" \
-    "flax>=0.4.1" \
-    "jaxlib>=0.1.65" && \
-    python3 -m uv pip install --no-cache-dir \
-    accelerate \
-    datasets \
-    hf-doc-builder \
-    huggingface-hub \
-    Jinja2 \
-    librosa \
-    numpy==1.26.4 \
-    scipy \
-    tensorboard \
-    transformers \
-    hf_transfer
-
-CMD ["/bin/bash"]
@@ -1,51 +0,0 @@
-FROM ubuntu:20.04
-LABEL maintainer="Hugging Face"
-LABEL repository="diffusers"
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-RUN apt-get -y update \
-    && apt-get install -y software-properties-common \
-    && add-apt-repository ppa:deadsnakes/ppa
-
-RUN apt install -y bash \
-    build-essential \
-    git \
-    git-lfs \
-    curl \
-    ca-certificates \
-    libsndfile1-dev \
-    libgl1 \
-    python3.10 \
-    python3-pip \
-    python3.10-venv && \
-    rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-
-# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
-RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
-    python3 -m pip install --no-cache-dir \
-    "jax[tpu]>=0.2.16,!=0.3.2" \
-    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
-    python3 -m uv pip install --upgrade --no-cache-dir \
-    clu \
-    "flax>=0.4.1" \
-    "jaxlib>=0.1.65" && \
-    python3 -m uv pip install --no-cache-dir \
-    accelerate \
-    datasets \
-    hf-doc-builder \
-    huggingface-hub \
-    Jinja2 \
-    librosa \
-    numpy==1.26.4 \
-    scipy \
-    tensorboard \
-    transformers \
-    hf_transfer
-
-CMD ["/bin/bash"]
@@ -19,6 +19,11 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
     """Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim_head * self.heads
         self.scale = self.dim_head**-0.5
 
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttention(
             self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
     split_head_dim: bool = False
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # The second linear layer needs to be called
         # net_2 for now to match the index of the Sequential layer
         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.dropout)
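All of the module changes above follow one pattern: the warning is emitted from `setup()`, which flax.linen calls when a module is bound during `init` or `apply`, so a deprecated module warns on first use rather than at import time. A minimal self-contained sketch of that pattern (the module name and message here are stand-ins, not diffusers code):

    import logging

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    logger = logging.getLogger(__name__)


    class DeprecatedBlock(nn.Module):
        features: int = 8

        def setup(self):
            # Runs when the module is bound (init/apply), not at import time.
            logger.warning("This Flax class is deprecated; migrate to the PyTorch equivalent.")
            self.dense = nn.Dense(self.features)

        def __call__(self, x):
            return self.dense(x)


    # The warning is logged here, on first use:
    params = DeprecatedBlock().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))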
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
 from ..modeling_flax_utils import FlaxModelMixin
 from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxControlNetOutput(BaseOutput):
     """
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv_in = nn.Conv(
             self.block_out_channels[0],
             kernel_size=(3, 3),
@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
@@ -16,6 +16,11 @@ import math
 import flax.linen as nn
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 def get_sinusoidal_embeddings(
     timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
             The data type for the embedding parameters.
     """
 
+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     time_embed_dim: int = 32
     dtype: jnp.dtype = jnp.float32
 
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
     flip_sin_to_cos: bool = False
     freq_shift: float = 1
 
+    logger.warning(
+        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+    )
+
     @nn.compact
     def __call__(self, timesteps):
         return get_sinusoidal_embeddings(
@@ -290,6 +290,10 @@ class FlaxModelMixin(PushToHubMixin):
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         config = kwargs.pop("config", None)
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
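The `from_pretrained` warning above goes through diffusers' own logger rather than Python's `warnings` module, so users who deliberately stay on Flax for now can hide it by lowering the library verbosity. A sketch, assuming the standard `diffusers.utils.logging` helpers:

    from diffusers.utils import logging

    # Suppresses warning-level messages, including the Flax deprecation notice.
    logging.set_verbosity_error()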
@@ -15,12 +15,22 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
 
 class FlaxUpsample2D(nn.Module):
     out_channels: int
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
         self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -15,10 +15,14 @@
 import flax.linen as nn
 import jax.numpy as jnp
 
+from ...utils import logging
 from ..attention_flax import FlaxTransformer2DModel
 from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
 
 
+logger = logging.get_logger(__name__)
+
+
 class FlaxCrossAttnDownBlock2D(nn.Module):
     r"""
     Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         attentions = []
 
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
 
         for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         attentions = []
 
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
 
         for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     transformer_layers_per_block: int = 1
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
 from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
 from ..modeling_flax_utils import FlaxModelMixin
 from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxUNet2DConditionOutput(BaseOutput):
     """
@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
 
     def setup(self) -> None:
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
@@ -25,10 +25,13 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .modeling_flax_utils import FlaxModelMixin
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class FlaxDecoderOutput(BaseOutput):
     """
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.conv = nn.Conv(
             self.in_channels,
             kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.conv = nn.Conv(
             self.in_channels,
             kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
         self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
 
         dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
    def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnets = []
         for i in range(self.num_layers):
             in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
         # in
         self.conv_in = nn.Conv(
@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         block_out_channels = self.block_out_channels
 
         # z to block_in
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         self.encoder = FlaxEncoder(
             in_channels=self.config.in_channels,
             out_channels=self.config.latent_channels,
@@ -312,6 +312,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
         >>> dpm_params["scheduler"] = dpmpp_state
         ```
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
+
         cache_dir = kwargs.pop("cache_dir", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -22,9 +22,11 @@ import flax
 import jax.numpy as jnp
 from huggingface_hub.utils import validate_hf_hub_args
 
-from ..utils import BaseOutput, PushToHubMixin
+from ..utils import BaseOutput, PushToHubMixin, logging
 
 
+logger = logging.get_logger(__name__)
+
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 
 
@@ -133,6 +135,10 @@ class FlaxSchedulerMixin(PushToHubMixin):
         </Tip>
 
         """
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         config, kwargs = cls.load_config(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             subfolder=subfolder,
@@ -1,39 +0,0 @@
-import unittest
-
-from diffusers import FlaxAutoencoderKL
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-from ..test_modeling_common_flax import FlaxModelTesterMixin
-
-
-if is_flax_available():
-    import jax
-
-
-@require_flax
-class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
-    model_class = FlaxAutoencoderKL
-
-    @property
-    def dummy_input(self):
-        batch_size = 4
-        num_channels = 3
-        sizes = (32, 32)
-
-        prng_key = jax.random.PRNGKey(0)
-        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
-
-        return {"sample": image, "prng_key": prng_key}
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
@@ -1,66 +0,0 @@
-import inspect
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
-    import jax
-
-
-@require_flax
-class FlaxModelTesterMixin:
-    def test_output(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
-        jax.lax.stop_gradient(variables)
-
-        output = model.apply(variables, inputs_dict["sample"])
-
-        if isinstance(output, dict):
-            output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_forward_with_norm_groups(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["norm_num_groups"] = 16
-        init_dict["block_out_channels"] = (16, 32)
-
-        model = self.model_class(**init_dict)
-        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
-        jax.lax.stop_gradient(variables)
-
-        output = model.apply(variables, inputs_dict["sample"])
-
-        if isinstance(output, dict):
-            output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_deprecated_kwargs(self):
-        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
-        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
-
-        if has_kwarg_in_model_class and not has_deprecated_kwarg:
-            raise ValueError(
-                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
-                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
-                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
-                " [<deprecated_argument>]`"
-            )
-
-        if not has_kwarg_in_model_class and has_deprecated_kwarg:
-            raise ValueError(
-                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
-                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
-                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
-                " from `_deprecated_kwargs = [<deprecated_argument>]`"
-            )
@@ -1,104 +0,0 @@
-import gc
-import unittest
-
-from parameterized import parameterized
-
-from diffusers import FlaxUNet2DConditionModel
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
-
-
-if is_flax_available():
-    import jax
-    import jax.numpy as jnp
-
-
-@slow
-@require_flax
-class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
-    def get_file_format(self, seed, shape):
-        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-
-    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
-        dtype = jnp.bfloat16 if fp16 else jnp.float32
-        image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
-        return image
-
-    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
-        dtype = jnp.bfloat16 if fp16 else jnp.float32
-        revision = "bf16" if fp16 else None
-
-        model, params = FlaxUNet2DConditionModel.from_pretrained(
-            model_id, subfolder="unet", dtype=dtype, revision=revision
-        )
-        return model, params
-
-    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
-        dtype = jnp.bfloat16 if fp16 else jnp.float32
-        hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
-        return hidden_states
-
-    @parameterized.expand(
-        [
-            # fmt: off
-            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
-            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
-            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
-            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
-            # fmt: on
-        ]
-    )
-    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
-        model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
-        latents = self.get_latents(seed, fp16=True)
-        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
-
-        sample = model.apply(
-            {"params": params},
-            latents,
-            jnp.array(timestep, dtype=jnp.int32),
-            encoder_hidden_states=encoder_hidden_states,
-        ).sample
-
-        assert sample.shape == latents.shape
-
-        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
-        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
-        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
-        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
-
-    @parameterized.expand(
-        [
-            # fmt: off
-            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
-            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
-            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
-            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
-            # fmt: on
-        ]
-    )
-    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
-        model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
-        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
-        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
-
-        sample = model.apply(
-            {"params": params},
-            latents,
-            jnp.array(timestep, dtype=jnp.int32),
-            encoder_hidden_states=encoder_hidden_states,
-        ).sample
-
-        assert sample.shape == latents.shape
-
-        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
-        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
-        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
-        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
@@ -1,127 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
-    import jax
-    import jax.numpy as jnp
-    from flax.jax_utils import replicate
-    from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-
-    def test_canny(self):
-        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
-            "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
-        )
-        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-            "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
-        )
-        params["controlnet"] = controlnet_params
-
-        prompts = "bird"
-        num_samples = jax.device_count()
-        prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
-        canny_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
-        )
-        processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
-
-        rng = jax.random.PRNGKey(0)
-        rng = jax.random.split(rng, jax.device_count())
-
-        p_params = replicate(params)
-        prompt_ids = shard(prompt_ids)
-        processed_image = shard(processed_image)
-
-        images = pipe(
-            prompt_ids=prompt_ids,
-            image=processed_image,
-            params=p_params,
-            prng_seed=rng,
-            num_inference_steps=50,
-            jit=True,
-        ).images
-        assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
-        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-        image_slice = images[0, 253:256, 253:256, -1]
-
-        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
-        expected_slice = jnp.array(
-            [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
-        )
-
-        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
-    def test_pose(self):
-        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
-            "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
-        )
-        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-            "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
-        )
-        params["controlnet"] = controlnet_params
-
-        prompts = "Chef in the kitchen"
-        num_samples = jax.device_count()
-        prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
-        pose_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
-        )
-        processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
-
-        rng = jax.random.PRNGKey(0)
-        rng = jax.random.split(rng, jax.device_count())
-
-        p_params = replicate(params)
-        prompt_ids = shard(prompt_ids)
-        processed_image = shard(processed_image)
-
-        images = pipe(
-            prompt_ids=prompt_ids,
-            image=processed_image,
-            params=p_params,
-            prng_seed=rng,
-            num_inference_steps=50,
-            jit=True,
-        ).images
-        assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
-        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-        image_slice = images[0, 253:256, 253:256, -1]
-
-        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
-        expected_slice = jnp.array(
-            [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
-        )
-
-        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@@ -1,108 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import nightly, require_flax
-
-
-if is_flax_available():
-    import jax
-    import jax.numpy as jnp
-    from flax.jax_utils import replicate
-    from flax.training.common_utils import shard
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-
-    def test_stable_diffusion_flax(self):
-        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-2",
-            variant="bf16",
-            dtype=jnp.bfloat16,
-        )
-
-        prompt = "A painting of a squirrel eating a burger"
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = sd_pipe.prepare_inputs(prompt)
-
-        params = replicate(params)
-        prompt_ids = shard(prompt_ids)
-
-        prng_seed = jax.random.PRNGKey(0)
-        prng_seed = jax.random.split(prng_seed, jax.device_count())
-
-        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
-        assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
-        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-        image_slice = images[0, 253:256, 253:256, -1]
-
-        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
-        expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
-
-        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-
-    def test_stable_diffusion_dpm_flax(self):
-        model_id = "stabilityai/stable-diffusion-2"
-        scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
-        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
-            model_id,
-            scheduler=scheduler,
-            variant="bf16",
-            dtype=jnp.bfloat16,
-        )
-        params["scheduler"] = scheduler_params
-
-        prompt = "A painting of a squirrel eating a burger"
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = sd_pipe.prepare_inputs(prompt)
-
-        params = replicate(params)
-        prompt_ids = shard(prompt_ids)
-
-        prng_seed = jax.random.PRNGKey(0)
-        prng_seed = jax.random.split(prng_seed, jax.device_count())
-
-        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
-        assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
-        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-        image_slice = images[0, 253:256, 253:256, -1]
-
-        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
-        expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
-
-        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@@ -1,82 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxStableDiffusionInpaintPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
-    import jax
-    import jax.numpy as jnp
-    from flax.jax_utils import replicate
-    from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-
-    def test_stable_diffusion_inpaint_pipeline(self):
-        init_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/sd2-inpaint/init_image.png"
-        )
-        mask_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
-        )
-
-        model_id = "xvjiarui/stable-diffusion-2-inpainting"
-        pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
-
-        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 50
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        init_image = num_samples * [init_image]
-        mask_image = num_samples * [mask_image]
-        prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, jax.device_count())
-        prompt_ids = shard(prompt_ids)
-        processed_masked_images = shard(processed_masked_images)
-        processed_masks = shard(processed_masks)
-
-        output = pipeline(
-            prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
-        )
-
-        images = output.images.reshape(num_samples, 512, 512, 3)
-
-        image_slice = images[0, 253:256, 253:256, -1]
-
-        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
-        expected_slice = jnp.array(
-            [0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
-        )
-
-        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@@ -1,260 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
-    import jax
-    import jax.numpy as jnp
-    from flax.jax_utils import replicate
-    from flax.training.common_utils import shard
-
-    from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
-
-
-@require_flax
-class DownloadTests(unittest.TestCase):
-    def test_download_only_pytorch(self):
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            # pipeline has Flax weights
-            _ = FlaxDiffusionPipeline.from_pretrained(
-                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
-            )
-
-            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
-            files = [item for sublist in all_root_files for item in sublist]
-
-            # None of the downloaded files should be a PyTorch file even if we have some here:
-            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
-            assert not any(f.endswith(".bin") for f in files)
-
-
-@slow
-@require_flax
-class FlaxPipelineTests(unittest.TestCase):
-    def test_dummy_all_tpus(self):
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
-        )
-
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 4
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = pipeline.prepare_inputs(prompt)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, num_samples)
-        prompt_ids = shard(prompt_ids)
-
-        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
-        assert images.shape == (num_samples, 1, 64, 64, 3)
-        if jax.device_count() == 8:
-            assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
-            assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
-
-        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
-        assert len(images_pil) == num_samples
-
-    def test_stable_diffusion_v1_4(self):
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
-        )
-
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 50
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = pipeline.prepare_inputs(prompt)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, num_samples)
-        prompt_ids = shard(prompt_ids)
-
-        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
-        assert images.shape == (num_samples, 1, 512, 512, 3)
-        if jax.device_count() == 8:
-            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
-            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
-
-    def test_stable_diffusion_v1_4_bfloat_16(self):
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None
-        )
-
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 50
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = pipeline.prepare_inputs(prompt)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, num_samples)
-        prompt_ids = shard(prompt_ids)
-
-        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
-        assert images.shape == (num_samples, 1, 512, 512, 3)
-        if jax.device_count() == 8:
-            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
-            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
-    def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16
-        )
-
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 50
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = pipeline.prepare_inputs(prompt)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, num_samples)
-        prompt_ids = shard(prompt_ids)
-
-        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
-        assert images.shape == (num_samples, 1, 512, 512, 3)
-        if jax.device_count() == 8:
-            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
-            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
-    def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
-        scheduler = FlaxDDIMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            set_alpha_to_one=False,
-            steps_offset=1,
-        )
-
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4",
-            variant="bf16",
-            dtype=jnp.bfloat16,
-            scheduler=scheduler,
-            safety_checker=None,
-        )
-        scheduler_state = scheduler.create_state()
-
-        params["scheduler"] = scheduler_state
-
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        prng_seed = jax.random.PRNGKey(0)
-        num_inference_steps = 50
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prompt_ids = pipeline.prepare_inputs(prompt)
-
-        # shard inputs and rng
-        params = replicate(params)
-        prng_seed = jax.random.split(prng_seed, num_samples)
-        prompt_ids = shard(prompt_ids)
-
-        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
-        assert images.shape == (num_samples, 1, 512, 512, 3)
-        if jax.device_count() == 8:
-            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
-            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
-
-    def test_jax_memory_efficient_attention(self):
-        prompt = (
-            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
-            " field, close up, split lighting, cinematic"
-        )
-
-        num_samples = jax.device_count()
-        prompt = num_samples * [prompt]
-        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
-
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4",
-            variant="bf16",
-            dtype=jnp.bfloat16,
-            safety_checker=None,
-        )
-
-        params = replicate(params)
-        prompt_ids = pipeline.prepare_inputs(prompt)
-        prompt_ids = shard(prompt_ids)
-        images = pipeline(prompt_ids, params, prng_seed, jit=True).images
-        assert images.shape == (num_samples, 1, 512, 512, 3)
-        slice = images[2, 0, 256, 10:17, 1]
-
-        # With memory efficient attention
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4",
-            variant="bf16",
-            dtype=jnp.bfloat16,
-            safety_checker=None,
-            use_memory_efficient_attention=True,
-        )
-
-        params = replicate(params)
-        prompt_ids = pipeline.prepare_inputs(prompt)
-        prompt_ids = shard(prompt_ids)
-        images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
-        assert images_eff.shape == (num_samples, 1, 512, 512, 3)
-        slice_eff = images[2, 0, 256, 10:17, 1]
-
-        # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
-        # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
-        assert abs(slice_eff - slice).max() < 1e-2
@@ -1,920 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
jax_device = jax.default_backend()
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_classes = ()
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
key1, key2 = random.split(random.PRNGKey(0))
|
||||
sample = random.uniform(key1, (batch_size, num_channels, height, width))
|
||||
|
||||
return sample, key2
|
||||
|
||||
@property
|
||||
def dummy_sample_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * height * width
|
||||
sample = jnp.arange(num_elems)
|
||||
sample = sample.reshape(num_channels, height, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
return jnp.transpose(sample, (3, 0, 1, 2))
|
||||
|
||||
def get_scheduler_config(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def dummy_model(self):
|
||||
def model(sample, t, *args):
|
||||
return sample * t / (t + 1)
|
||||
|
||||
return model
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
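            # `t != t` is only True for NaN entries; `.at[...].set(0)` is JAX's
            # functional (out-of-place) update, so NaNs compare equal afterwards.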
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_deprecated_kwargs(self):
        for scheduler_class in self.scheduler_classes:
            has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
            has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0

            if has_kwarg_in_model_class and not has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
                    " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                    " [<deprecated_argument>]`"
                )

            if not has_kwarg_in_model_class and has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
                    f" argument to {scheduler_class}.__init__ if there are deprecated arguments or remove the"
                    " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
                )


@require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
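        # For `fixed_small`, the expected variance is beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t);
        # the three spot checks below evaluate it at t = 0, 487, and 999.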

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))
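        # JAX PRNG discipline: keys are split, never reused; `key2` seeds the next
        # split inside the loop below.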

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 255.0714) < 1e-2
            assert abs(result_mean - 0.332124) < 1e-3
        else:
            assert abs(result_sum - 270.2) < 1e-1
            assert abs(result_mean - 0.3519494) < 1e-3


@require_flax
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDIMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def full_loop(self, **config):
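        # Regression helper: run 10 DDIM steps of the dummy model and return the final
        # sample, whose sum/mean the tests below compare against recorded values.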
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        key1, key2 = random.split(random.PRNGKey(0))

        num_inference_steps = 10

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        state = scheduler.set_timesteps(state, num_inference_steps)

        for t in state.timesteps:
            residual = model(sample, t)
            output = scheduler.step(state, residual, t, sample)
            sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

        return sample

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 5)
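        # With 1000 train steps and 5 inference steps the stride is 200, giving
        # [800, 600, 400, 200, 0]; steps_offset=1 shifts every timestep up by one.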
        assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 172.0067) < 1e-2
        assert abs(result_mean - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 149.8409) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3
        else:
            assert abs(result_sum - 149.8295) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            pass
            # FIXME: both result_sum and result_mean are nan on TPU
            # assert jnp.isnan(result_sum)
            # assert jnp.isnan(result_mean)
        else:
            assert abs(result_sum - 149.0784) < 1e-2
            assert abs(result_mean - 0.1941) < 1e-3

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)


@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals
            state = state.replace(ets=dummy_past_residuals[:])
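            # (Flax scheduler states are immutable dataclasses, so `replace` returns
            # a new state rather than mutating in place.)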

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
                # copy over dummy past residuals
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    @unittest.skip("Test not supported.")
    def test_from_save_pretrained(self):
        pass

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # copy over dummy past residuals (must be after setting timesteps)
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)

                # copy over dummy past residual (must be after setting timesteps)
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
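        # PNDM runs a Runge-Kutta warm-up (`prk_timesteps`) to build up past residuals,
        # then switches to the cheaper linear multistep phase (`plms_timesteps`).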
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

        for i, t in enumerate(state.prk_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_prk(state, residual, t, sample)

        for i, t in enumerate(state.plms_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_plms(state, residual, t, sample)

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
            state = state.replace(ets=dummy_past_residuals[:])

            output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 10, shape=())
        assert jnp.equal(
            state.timesteps,
            jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
        ).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # before power of 3 fix, would error on first step, so we only need to do two
            for i, t in enumerate(state.prk_timesteps[:2]):
                sample, state = scheduler.step_prk(state, residual, t, sample)

    def test_inference_plms_no_past_residuals(self):
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            scheduler.step_plms(state, sample, 1, sample).prev_sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 198.1275) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3
        else:
            assert abs(result_sum - 198.1318) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9466) < 1e-2
            assert abs(result_mean - 0.24342) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9482) < 1e-2
            assert abs(result_mean - 0.2434) < 1e-3