Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Flax memory efficient attention (#2889)
* add use_memory_efficient params placeholder
* test
* add memory efficient attention jax
* add memory efficient attention jax
* newline
* forgot dot
* Rename use_memory_efficient
* Keep dtype last.
* Actually use key_chunk_size
* Rename symbol
* Apply style
* Rename use_memory_efficient
* Keep dtype last
* Pass `use_memory_efficient_attention` in `from_pretrained`
* Move JAX memory efficient attention to attention_flax.
* Simple test.
* style

---------

Co-authored-by: muhammad_hanif <muhammad_hanif@sofcograha.co.id>
Co-authored-by: MuhHanif <48muhhanif@gmail.com>
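For reference, a minimal usage sketch (editorial, mirroring the test added at the end of this diff): the new flag is accepted by `from_pretrained` and forwarded to every Flax attention module. The model id, revision, and dtype below are simply the values used in that test.

    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionPipeline

    # Opt in to the memory-efficient attention path for all Flax attention layers.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="bf16",
        dtype=jnp.bfloat16,
        safety_checker=None,
        use_memory_efficient_attention=True,
    )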
@@ -12,10 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights
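
# Editorial sketch (not part of this diff): the bookkeeping above is the standard online-softmax
# trick. Each key chunk contributes partial exponentiated values, partial weight sums, and its own
# running max; rescaling every chunk by exp(chunk_max - global_max) before summing makes the result
# match ordinary softmax attention (up to floating-point error). A quick sanity check with toy
# shapes (length, heads, depth):
#
#     q = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 4))
#     k = jax.random.normal(jax.random.PRNGKey(1), (16, 2, 4))
#     v = jax.random.normal(jax.random.PRNGKey(2), (16, 2, 4))
#     chunked = _query_chunk_attention(q, k, v, precision=jax.lax.Precision.HIGHEST, key_chunk_size=4)
#     scores = jnp.einsum("qhd,khd->qhk", q, k) / jnp.sqrt(q.shape[-1])
#     reference = jnp.einsum("qhk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)
#     assert jnp.allclose(chunked, reference, atol=1e-4)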


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax memory-efficient multi-head dot-product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for the computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide the query array; must divide query_length without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide the key and value arrays; must divide key_value_length without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # scan carry; unused by the caller
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,  # start counter
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back together
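
# Editorial sketch (not part of this diff): example call for self-attention over 4096 flattened
# latent positions with 8 heads of depth 40, following the shape convention in the docstring.
# Peak memory per scan step is roughly query_chunk_size x key_chunk_size attention scores instead
# of the full query_length x key_value_length matrix.
#
#     q = jax.random.normal(jax.random.PRNGKey(0), (4096, 8, 40))
#     out = jax_memory_efficient_attention(q, q, q, query_chunk_size=1024, key_chunk_size=4096)
#     assert out.shape == (4096, 8, 40)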


class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -77,13 +180,38 @@ class FlaxAttention(nn.Module):
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        # compute attentions
        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
        attention_scores = attention_scores * self.scale
        attention_probs = nn.softmax(attention_scores, axis=2)
        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if/elif chain picks a chunk size for each layer of the unet;
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)
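
            # Editorial note (not part of this diff): for 512x512 Stable Diffusion inputs the largest
            # self-attention layers see 64x64 = 4096 flattened latent positions, so flatten_latent_dim
            # is 4096, the first branch fires (4096 % 64 == 0), and query_chunk_size becomes 64, which
            # matches the 8x8 = 64 query positions of the deepest attention layer.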
            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)

        # attend to values
        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return hidden_states
@@ -108,6 +236,8 @@ class FlaxBasicTransformerBlock(nn.Module):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
@@ -115,12 +245,17 @@ class FlaxBasicTransformerBlock(nn.Module):
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.attn1 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        # cross attention
        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.attn2 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ class FlaxTransformer2DModel(nn.Module):
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]

@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -72,6 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -209,6 +216,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
            Number of attention blocks layers
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -341,6 +352,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                d_head=self.in_channels // self.attn_num_head_channels,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682

    """

@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
@@ -169,6 +172,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
@@ -190,6 +194,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
            dropout=self.dropout,
            attn_num_head_channels=attention_head_dim[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )

@@ -217,6 +222,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:

@@ -296,6 +296,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pt = kwargs.pop("from_pt", False)
        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
        dtype = kwargs.pop("dtype", None)

        # 1. Download the checkpoints and configs
@@ -451,7 +452,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
                    loaded_sub_model = cached_folder

                if issubclass(class_obj, FlaxModelMixin):
                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
                    loaded_sub_model, loaded_params = load_method(
                        loadable_folder,
                        from_pt=from_pt,
                        use_memory_efficient_attention=use_memory_efficient_attention,
                        dtype=dtype,
                    )
                    params[name] = loaded_params
                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                    if from_pt:

@@ -215,3 +215,47 @@ class FlaxPipelineTests(unittest.TestCase):
        if jax.device_count() == 8:
            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1

    def test_jax_memory_efficient_attention(self):
        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            safety_checker=None,
        )

        params = replicate(params)
        prompt_ids = pipeline.prepare_inputs(prompt)
        prompt_ids = shard(prompt_ids)
        images = pipeline(prompt_ids, params, prng_seed, jit=True).images
        assert images.shape == (num_samples, 1, 512, 512, 3)
        slice = images[2, 0, 256, 10:17, 1]

        # With memory efficient attention
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            safety_checker=None,
            use_memory_efficient_attention=True,
        )

        params = replicate(params)
        prompt_ids = pipeline.prepare_inputs(prompt)
        prompt_ids = shard(prompt_ids)
        images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
        assert images_eff.shape == (num_samples, 1, 512, 512, 3)
        slice_eff = images_eff[2, 0, 256, 10:17, 1]

        # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the
        # `sum` over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
        assert abs(slice_eff - slice).max() < 1e-2