Mirror of https://github.com/huggingface/diffusers.git
[core] Freenoise memory improvements (#9262)
* update
* implement prompt interpolation
* make style
* resnet memory optimizations
* more memory optimizations; todo: refactor
* update
* update animatediff controlnet with latest changes
* refactor chunked inference changes
* remove print statements
* update
* chunk -> split
* remove changes from incorrect conflict resolution
* remove changes from incorrect conflict resolution
* add explanation of SplitInferenceModule
* update docs
* Revert "update docs"
This reverts commit c55a50a271.
* update docstring for freenoise split inference
* apply suggestions from review
* add tests
* apply suggestions from review
@@ -1104,8 +1104,26 @@ class FreeNoiseTransformerBlock(nn.Module):
             accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
             num_times_accumulated[:, frame_start:frame_end] += weights

-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
-        )
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        # hidden_states = torch.where(
+        #     num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # )
+        #
+        # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not looked
+        # into this particularly deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
+        ).to(dtype)

         # 3. Feed-forward
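The comment in this hunk is the heart of the change: instead of one full-size `torch.where`, the weighted average is computed per `context_length`-sized chunk. A minimal, self-contained sketch of that idea (tensor names and shapes here are illustrative, not the pipeline's actual buffers):

```python
import torch

# Illustrative shapes: (batch * spatial tokens, frames, channels)
accumulated_values = torch.randn(4, 64, 8)
num_times_accumulated = torch.randint(0, 3, (4, 64, 1)).float()
context_length = 16

# One-shot torch.where: the division and the where() output are full-size temporaries.
full = torch.where(
    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
)

# Chunked variant: same result, but intermediate tensors only ever span `context_length` frames.
chunked = torch.cat(
    [
        torch.where(n > 0, a / n, a)
        for a, n in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)

assert torch.allclose(full, chunked)
```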
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_in(input=hidden_states)

         # 2. Blocks
         for block in self.transformer_blocks:
             hidden_states = block(
-                hidden_states,
+                hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
             )

         # 3. Output
-        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.proj_out(input=hidden_states)
         hidden_states = (
             hidden_states[None, None, :]
             .reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
                     )

             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states

@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
                         create_custom_forward(resnet), hidden_states, temb
                     )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states
@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         if cross_attention_kwargs.get("scale", None) is not None:
             logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                     return custom_forward

                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(motion_module),
                     hidden_states,
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                     **ckpt_kwargs,
                 )
             else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = motion_module(
                     hidden_states,
                     num_frames=num_frames,
                 )
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

         return hidden_states

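The repeated positional-to-keyword changes in the hunks above are not cosmetic: the `SplitInferenceModule` introduced later in this PR splits inputs by keyword name, so wrapped submodules (resnets, attentions, samplers, projections) must receive their tensors as keyword arguments. A stripped-down sketch of that constraint, using a stand-in wrapper written for illustration rather than the class added by this PR:

```python
import torch
import torch.nn as nn


class KwargSplitWrapper(nn.Module):
    """Illustrative stand-in: splits selected keyword arguments, runs the module per chunk, concatenates outputs."""

    def __init__(self, module: nn.Module, split_size: int, keys=("hidden_states",)):
        super().__init__()
        self.module = module
        self.split_size = split_size
        self.keys = set(keys)

    def forward(self, *args, **kwargs):
        # Only keyword arguments can be found by name and split.
        splits = {
            key: torch.split(kwargs.pop(key), self.split_size, dim=0)
            for key in list(kwargs)
            if key in self.keys and torch.is_tensor(kwargs[key])
        }
        outputs = [
            self.module(*args, **dict(zip(splits.keys(), chunk)), **kwargs)
            for chunk in zip(*splits.values())
        ]
        return torch.cat(outputs, dim=0)


layer = KwargSplitWrapper(nn.Linear(8, 8), split_size=2, keys=("input",))
x = torch.randn(6, 8)

# Works: the tensor is passed by keyword, so the wrapper can find and split it.
y = layer(input=x)

# If the tensor were passed positionally, it would hide in *args and bypass the splitting logic,
# which is why the call sites in unet_motion_model.py switch to keyword arguments.
```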
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.nn as nn

 from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
+from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..models.transformers.transformer_2d import Transformer2DModel
 from ..models.unets.unet_motion_model import (
+    AnimateDiffTransformer3D,
     CrossAttnDownBlockMotion,
     DownBlockMotion,
     UpBlockMotion,
@@ -30,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class SplitInferenceModule(nn.Module):
+    r"""
+    A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
+
+    This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
+    them into smaller chunks, processing each chunk separately, and then reassembling the results.
+
+    Args:
+        module (`nn.Module`):
+            The underlying PyTorch module that will be applied to each chunk of split inputs.
+        split_size (`int`, defaults to `1`):
+            The size of each chunk after splitting the input tensor.
+        split_dim (`int`, defaults to `0`):
+            The dimension along which the input tensors are split.
+        input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
+            A list of keyword arguments (strings) that represent the input tensors to be split.
+
+    Workflow:
+        1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
+           `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
+        2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
+           that were passed.
+        3. The output tensors from each split are concatenated back together along `split_dim` before returning.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+
+    >>> model = nn.Linear(1000, 1000)
+    >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
+
+    >>> input_tensor = torch.randn(42, 1000)
+    >>> # Will split the tensor into 21 slices of shape [2, 1000].
+    >>> output = split_module(input=input_tensor)
+    ```
+
+    It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
+    multi-dimensional splitting.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
+
+    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        r"""Forward method for the `SplitInferenceModule`.
+
+        This method processes the input by splitting specified keyword arguments along a given dimension, running the
+        underlying module on each split, and then concatenating the results. The splitting is controlled by the
+        `split_size` and `split_dim` parameters specified during initialization.
+
+        Args:
+            *args (`Any`):
+                Positional arguments that are passed directly to the `module` without modification.
+            **kwargs (`Dict[str, torch.Tensor]`):
+                Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
+                entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
+                arguments are passed unchanged.
+
+        Returns:
+            `Union[torch.Tensor, Tuple[torch.Tensor]]`:
+                The outputs are the same as if the underlying module had been invoked directly on the full inputs.
+                - If the underlying module returns a single tensor, the result will be a single concatenated tensor
+                  along the same `split_dim` after processing all splits.
+                - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
+                  along the `split_dim` across all splits, and the final result will be a tuple of concatenated
+                  tensors.
+        """
+        split_inputs = {}
+
+        # 1. Split inputs that were specified during initialization and also present in passed kwargs
+        for key in list(kwargs.keys()):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
+                continue
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
+            kwargs.pop(key)
+
+        # 2. Invoke forward pass across each split
+        results = []
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
+            inputs.update(kwargs)
+
+            intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
+            results.append(intermediate_tensor_or_tensor_tuple)
+
+        # 3. Concatenate split results to obtain final outputs
+        if isinstance(results[0], torch.Tensor):
+            return torch.cat(results, dim=self.split_dim)
+        elif isinstance(results[0], tuple):
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
+        else:
+            raise ValueError(
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+            )
+
+
 class AnimateDiffFreeNoiseMixin:
     r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""

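The docstring above notes that `SplitInferenceModule` can be nested across different split dimensions; a small sketch of what that looks like (the module and sizes are arbitrary, chosen only so the shapes line up):

```python
import torch
import torch.nn as nn

from diffusers.pipelines.free_noise_utils import SplitInferenceModule

model = nn.Linear(16, 16)

# The inner wrapper splits along dim 1 (e.g. frames), the outer along dim 0 (e.g. batch),
# so each underlying forward pass only ever sees a [2, 4, 16] slice of the input.
inner = SplitInferenceModule(model, split_size=4, split_dim=1, input_kwargs_to_split=["input"])
outer = SplitInferenceModule(inner, split_size=2, split_dim=0, input_kwargs_to_split=["input"])

x = torch.randn(8, 12, 16)
out = outer(input=x)
assert out.shape == (8, 12, 16)
```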
@@ -70,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
                 motion_module.transformer_blocks[i].load_state_dict(
                     basic_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
+                )

     def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
         r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -98,6 +213,9 @@ class AnimateDiffFreeNoiseMixin:
                 motion_module.transformer_blocks[i].load_state_dict(
                     free_noise_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
+                )

     def _check_inputs_free_noise(
         self,
@@ -410,6 +528,69 @@ class AnimateDiffFreeNoiseMixin:
         for block in blocks:
             self._disable_free_noise_in_block(block)

+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
+    ) -> None:
+        for motion_module in motion_modules:
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
+
+            for i in range(len(motion_module.transformer_blocks)):
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
+                    motion_module.transformer_blocks[i],
+                    spatial_split_size,
+                    0,
+                    ["hidden_states", "encoder_hidden_states"],
+                )
+
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
+
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
+    ) -> None:
+        for i in range(len(attentions)):
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
+            )
+
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
+        for i in range(len(resnets)):
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
+
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
+    ) -> None:
+        for i in range(len(samplers)):
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
+
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
+        r"""
+        Enable FreeNoise memory optimizations by utilizing
+        [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
+
+        Args:
+            spatial_split_size (`int`, defaults to `256`):
+                The split size across spatial dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
+                modeling blocks.
+            temporal_split_size (`int`, defaults to `16`):
+                The split size across temporal dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in
+                spatial attention, resnets, downsampling and upsampling blocks.
+        """
+        # TODO(aryan): Discuss what the best way is to provide more control to users
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        for block in blocks:
+            if getattr(block, "motion_modules", None) is not None:
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
+            if getattr(block, "attentions", None) is not None:
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
+            if getattr(block, "resnets", None) is not None:
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
+            if getattr(block, "downsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
+            if getattr(block, "upsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
+
     @property
     def free_noise_enabled(self):
         return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
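A usage sketch of the `enable_free_noise_split_inference` API added above, on a standard AnimateDiff setup; the checkpoints and generation arguments are illustrative and not prescribed by this PR:

```python
import torch

from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints: any AnimateDiff-compatible base model and motion adapter should work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Enable FreeNoise first (context length / stride), then the split-inference memory optimization on top.
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)

frames = pipe(
    "a panda surfing on a wave, high quality",
    num_frames=64,
    num_inference_steps=25,
).frames[0]
```

Larger `spatial_split_size` / `temporal_split_size` values mean fewer, larger forward passes per wrapped module, trading memory savings for a smaller splitting overhead.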
@@ -460,6 +460,30 @@ class AnimateDiffPipelineFastTests(
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )

+    def test_free_noise_split_inference(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        pipe.enable_free_noise(8, 4)
+
+        inputs_normal = self.get_dummy_inputs(torch_device)
+        frames_normal = pipe(**inputs_normal).frames[0]
+
+        # Test FreeNoise with split inference memory-optimization
+        pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
+
+        inputs_enable_split_inference = self.get_dummy_inputs(torch_device)
+        frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
+
+        sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
+        self.assertLess(
+            sum_split_inference,
+            1e-4,
+            "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
+        )
+
     def test_free_noise_multi_prompt(self):
         components = self.get_dummy_components()
         pipe: AnimateDiffPipeline = self.pipeline_class(**components)
@@ -492,6 +492,34 @@ class AnimateDiffVideoToVideoPipelineFastTests(
             "Disabling of FreeNoise should lead to results similar to the default pipeline results",
         )

+    def test_free_noise_split_inference(self):
+        components = self.get_dummy_components()
+        pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        pipe.enable_free_noise(8, 4)
+
+        inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs_normal["num_inference_steps"] = 2
+        inputs_normal["strength"] = 0.5
+        frames_normal = pipe(**inputs_normal).frames[0]
+
+        # Test FreeNoise with split inference memory-optimization
+        pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
+
+        inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16)
+        inputs_enable_split_inference["num_inference_steps"] = 2
+        inputs_enable_split_inference["strength"] = 0.5
+        frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
+
+        sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
+        self.assertLess(
+            sum_split_inference,
+            1e-4,
+            "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
+        )
+
     def test_free_noise_multi_prompt(self):
         components = self.get_dummy_components()
         pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)