1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[core] Freenoise memory improvements (#9262)

* update

* implement prompt interpolation

* make style

* resnet memory optimizations

* more memory optimizations; todo: refactor

* update

* update animatediff controlnet with latest changes

* refactor chunked inference changes

* remove print statements

* update

* chunk -> split

* remove changes from incorrect conflict resolution

* remove changes from incorrect conflict resolution

* add explanation of SplitInferenceModule

* update docs

* Revert "update docs"

This reverts commit c55a50a271.

* update docstring for freenoise split inference

* apply suggestions from review

* add tests

* apply suggestions from review
This commit is contained in:
Aryan
2024-09-06 12:51:20 +05:30
committed by GitHub
parent 5249a2666e
commit 6dfa49963c
5 changed files with 294 additions and 64 deletions

View File

@@ -1104,8 +1104,26 @@ class FreeNoiseTransformerBlock(nn.Module):
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# TODO(aryan): Maybe this could be done in a better way.
#
# Previously, this was:
# hidden_states = torch.where(
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# )
#
# The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
# from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
# looked into this deeply because other memory optimizations led to more pronounced reductions.
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
).to(dtype)
# 3. Feed-forward

View File

@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states

View File

@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..models.transformers.transformer_2d import Transformer2DModel
from ..models.unets.unet_motion_model import (
AnimateDiffTransformer3D,
CrossAttnDownBlockMotion,
DownBlockMotion,
UpBlockMotion,
@@ -30,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SplitInferenceModule(nn.Module):
r"""
A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
them into smaller chunks, processing each chunk separately, and then reassembling the results.
Args:
module (`nn.Module`):
The underlying PyTorch module that will be applied to each chunk of split inputs.
split_size (`int`, defaults to `1`):
The size of each chunk after splitting the input tensor.
split_dim (`int`, defaults to `0`):
The dimension along which the input tensors are split.
input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
A list of keyword arguments (strings) that represent the input tensors to be split.
Workflow:
1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
`torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
that were passed.
3. The output tensors from each split are concatenated back together along `split_dim` before returning.
Example:
```python
>>> import torch
>>> import torch.nn as nn
>>> model = nn.Linear(1000, 1000)
>>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
>>> input_tensor = torch.randn(42, 1000)
>>> # Will split the tensor into 21 slices of shape [2, 1000].
>>> output = split_module(input=input_tensor)
```
It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
multi-dimensional splitting.
"""
def __init__(
self,
module: nn.Module,
split_size: int = 1,
split_dim: int = 0,
input_kwargs_to_split: List[str] = ["hidden_states"],
) -> None:
super().__init__()
self.module = module
self.split_size = split_size
self.split_dim = split_dim
self.input_kwargs_to_split = set(input_kwargs_to_split)
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
r"""Forward method for the `SplitInferenceModule`.
This method processes the input by splitting specified keyword arguments along a given dimension, running the
underlying module on each split, and then concatenating the results. The splitting is controlled by the
`split_size` and `split_dim` parameters specified during initialization.
Args:
*args (`Any`):
Positional arguments that are passed directly to the `module` without modification.
**kwargs (`Dict[str, torch.Tensor]`):
Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
arguments are passed unchanged.
Returns:
`Union[torch.Tensor, Tuple[torch.Tensor]]`:
The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
without it.
- If the underlying module returns a single tensor, the result will be a single concatenated tensor
along the same `split_dim` after processing all splits.
- If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
"""
split_inputs = {}
# 1. Split inputs that were specified during initialization and also present in passed kwargs
for key in list(kwargs.keys()):
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
continue
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
kwargs.pop(key)
# 2. Invoke forward pass across each split
results = []
for split_input in zip(*split_inputs.values()):
inputs = dict(zip(split_inputs.keys(), split_input))
inputs.update(kwargs)
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
results.append(intermediate_tensor_or_tensor_tuple)
# 3. Concatenate split restuls to obtain final outputs
if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=self.split_dim)
elif isinstance(results[0], tuple):
return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
else:
raise ValueError(
"In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
)
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
@@ -70,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
)
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -98,6 +213,9 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
)
def _check_inputs_free_noise(
self,
@@ -410,6 +528,69 @@ class AnimateDiffFreeNoiseMixin:
for block in blocks:
self._disable_free_noise_in_block(block)
def _enable_split_inference_motion_modules_(
self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
) -> None:
for motion_module in motion_modules:
motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
for i in range(len(motion_module.transformer_blocks)):
motion_module.transformer_blocks[i] = SplitInferenceModule(
motion_module.transformer_blocks[i],
spatial_split_size,
0,
["hidden_states", "encoder_hidden_states"],
)
motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
def _enable_split_inference_attentions_(
self, attentions: List[Transformer2DModel], temporal_split_size: int
) -> None:
for i in range(len(attentions)):
attentions[i] = SplitInferenceModule(
attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
)
def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
for i in range(len(resnets)):
resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
def _enable_split_inference_samplers_(
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
) -> None:
for i in range(len(samplers)):
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
r"""
Enable FreeNoise memory optimizations by utilizing
[`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
Args:
spatial_split_size (`int`, defaults to `256`):
The split size across spatial dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
modeling blocks.
temporal_split_size (`int`, defaults to `16`):
The split size across temporal dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
attention, resnets, downsampling and upsampling blocks.
"""
# TODO(aryan): Discuss on what's the best way to provide more control to users
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
if getattr(block, "motion_modules", None) is not None:
self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
if getattr(block, "attentions", None) is not None:
self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
if getattr(block, "resnets", None) is not None:
self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
if getattr(block, "downsamplers", None) is not None:
self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
if getattr(block, "upsamplers", None) is not None:
self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None

View File

@@ -460,6 +460,30 @@ class AnimateDiffPipelineFastTests(
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_free_noise_split_inference(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise(8, 4)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
# Test FreeNoise with split inference memory-optimization
pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
inputs_enable_split_inference = self.get_dummy_inputs(torch_device)
frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
self.assertLess(
sum_split_inference,
1e-4,
"Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
)
def test_free_noise_multi_prompt(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)

View File

@@ -492,6 +492,34 @@ class AnimateDiffVideoToVideoPipelineFastTests(
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_free_noise_split_inference(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise(8, 4)
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_normal["num_inference_steps"] = 2
inputs_normal["strength"] = 0.5
frames_normal = pipe(**inputs_normal).frames[0]
# Test FreeNoise with split inference memory-optimization
pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_enable_split_inference["num_inference_steps"] = 2
inputs_enable_split_inference["strength"] = 0.5
frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
self.assertLess(
sum_split_inference,
1e-4,
"Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
)
def test_free_noise_multi_prompt(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)