1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Flax memory efficient attention (#2889)

* add use_memory_efficient params placeholder

* test

* add memory efficient attention jax

* add memory efficient attention jax

* newline

* forgot dot

* Rename use_memory_efficient

* Keep dtype last.

* Actually use key_chunk_size

* Rename symbol

* Apply style

* Rename use_memory_efficient

* Keep dtype last

* Pass `use_memory_efficient_attention` in `from_pretrained`

* Move JAX memory efficient attention to attention_flax.

* Simple test.

* style

---------

Co-authored-by: muhammad_hanif <muhammad_hanif@sofcograha.co.id>
Co-authored-by: MuhHanif <48muhhanif@gmail.com>
This commit is contained in:
Pedro Cuenca
2023-04-12 11:17:51 +02:00
committed by GitHub
parent 0df47efee2
commit dc277501c7
5 changed files with 216 additions and 9 deletions

View File

@@ -12,10 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import flax.linen as nn
import jax
import jax.numpy as jnp
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
num_kv, num_heads, k_features = key.shape[-3:]
v_features = value.shape[-1]
key_chunk_size = min(key_chunk_size, num_kv)
query = query / jnp.sqrt(k_features)
@functools.partial(jax.checkpoint, prevent_cse=False)
def summarize_chunk(query, key, value):
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
max_score = jax.lax.stop_gradient(max_score)
exp_weights = jnp.exp(attn_weights - max_score)
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
max_score = jnp.einsum("...qhk->...qh", max_score)
return (exp_values, exp_weights.sum(axis=-1), max_score)
def chunk_scanner(chunk_idx):
# julienne key array
key_chunk = jax.lax.dynamic_slice(
operand=key,
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
)
# julienne value array
value_chunk = jax.lax.dynamic_slice(
operand=value,
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
)
return summarize_chunk(query, key_chunk, value_chunk)
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
global_max = jnp.max(chunk_max, axis=0, keepdims=True)
max_diffs = jnp.exp(chunk_max - global_max)
chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(axis=0)
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
return all_values / all_weights
def jax_memory_efficient_attention(
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
r"""
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
https://github.com/AminRezaei0x443/memory-efficient-attention
Args:
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
numerical precision for computation
query_chunk_size (`int`, *optional*, defaults to 1024):
chunk size to divide query array value must divide query_length equally without remainder
key_chunk_size (`int`, *optional*, defaults to 4096):
chunk size to divide key and value array value must divide key_value_length equally without remainder
Returns:
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
"""
num_q, num_heads, q_features = query.shape[-3:]
def chunk_scanner(chunk_idx, _):
# julienne query array
query_chunk = jax.lax.dynamic_slice(
operand=query,
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
)
return (
chunk_idx + query_chunk_size, # unused ignore it
_query_chunk_attention(
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
),
)
_, res = jax.lax.scan(
f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
)
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
class FlaxAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -77,13 +180,38 @@ class FlaxAttention(nn.Module):
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
if self.use_memory_efficient_attention:
query_states = query_states.transpose(1, 0, 2)
key_states = key_states.transpose(1, 0, 2)
value_states = value_states.transpose(1, 0, 2)
# this if statement create a chunk size for each layer of the unet
# the chunk size is equal to the query_length dimension of the deepest layer of the unet
flatten_latent_dim = query_states.shape[-3]
if flatten_latent_dim % 64 == 0:
query_chunk_size = int(flatten_latent_dim / 64)
elif flatten_latent_dim % 16 == 0:
query_chunk_size = int(flatten_latent_dim / 16)
elif flatten_latent_dim % 4 == 0:
query_chunk_size = int(flatten_latent_dim / 4)
else:
query_chunk_size = int(flatten_latent_dim)
hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)
hidden_states = hidden_states.transpose(1, 0, 2)
else:
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states
@@ -108,6 +236,8 @@ class FlaxBasicTransformerBlock(nn.Module):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
"""
dim: int
n_heads: int
@@ -115,12 +245,17 @@ class FlaxBasicTransformerBlock(nn.Module):
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
)
# cross attention
self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn2 = FlaxAttention(
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
"""
in_channels: int
n_heads: int
@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ class FlaxTransformer2DModel(nn.Module):
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
)
for _ in range(self.depth)
]

View File

@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -72,6 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -209,6 +216,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
Number of attention blocks layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -341,6 +352,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
attentions.append(attn_block)

View File

@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
"""
@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -169,6 +172,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
else:
@@ -190,6 +194,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
@@ -217,6 +222,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
)
else:

View File

@@ -296,6 +296,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
dtype = kwargs.pop("dtype", None)
# 1. Download the checkpoints and configs
@@ -451,7 +452,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model = cached_folder
if issubclass(class_obj, FlaxModelMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
loaded_sub_model, loaded_params = load_method(
loadable_folder,
from_pt=from_pt,
use_memory_efficient_attention=use_memory_efficient_attention,
dtype=dtype,
)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
if from_pt:

View File

@@ -215,3 +215,47 @@ class FlaxPipelineTests(unittest.TestCase):
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
def test_jax_memory_efficient_attention(self):
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
slice = images[2, 0, 256, 10:17, 1]
# With memory efficient attention
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=jnp.bfloat16,
safety_checker=None,
use_memory_efficient_attention=True,
)
params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids = shard(prompt_ids)
images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
assert images_eff.shape == (num_samples, 1, 512, 512, 3)
slice_eff = images[2, 0, 256, 10:17, 1]
# I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
# over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
assert abs(slice_eff - slice).max() < 1e-2