1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add explanation of SplitInferenceModule

This commit is contained in:
Aryan
2024-09-05 11:45:55 +02:00
parent 12f0ae11ba
commit 661a0b389d

View File

@@ -50,19 +50,65 @@ class SplitInferenceModule(nn.Module):
self.input_kwargs_to_split = set(input_kwargs_to_split)
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
r"""Forward method of `SplitInferenceModule`.
r"""Forward method for the `SplitInferenceModule`.
All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
split that are specified in `inputs_to_split` when initializing the module.
This method processes the input by splitting specified keyword arguments along a given dimension, running the
underlying module on each split, and then concatenating the results. The splitting is controlled by the
`split_size` and `split_dim` parameters specified during initialization.
Args:
*args (`Any`):
Positional arguments that are passed directly to the `module` without modification.
**kwargs (`Dict[str, torch.Tensor]`):
Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
arguments are passed unchanged.
Returns:
`Union[torch.Tensor, Tuple[torch.Tensor]]`:
The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
without it.
- If the underlying module returns a single tensor, the result will be a single concatenated tensor
along the same `split_dim` after processing all splits.
- If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
Workflow:
1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
`torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
that were passed.
3. The output tensors from each split are concatenated back together along `split_dim` before returning.
Example:
```python
>>> import torch
>>> model = nn.Linear(1000, 1000)
>>> split_module = SplitInferenceModule(
... model, split_size=2, split_dim=0, input_kwargs_to_split=["input_data"]
... )
>>> input_tensor = torch.randn(42, 1000)
>>> # Will split the tensor into 21 slices of shape [2, 1000].
>>> output = split_module(input_data=input_tensor)
```
This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
them into smaller chunks, processing each chunk separately, and then reassembling the results.
It is also possible to nest `SplitInferenceModule` across different split dimensions.
"""
split_inputs = {}
# 1. Split inputs that were specified during initialization and also present in passed kwargs
for key in list(kwargs.keys()):
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
continue
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
kwargs.pop(key)
# 2. Invoke forward pass across each split
results = []
for split_input in zip(*split_inputs.values()):
inputs = dict(zip(split_inputs.keys(), split_input))
@@ -71,6 +117,7 @@ class SplitInferenceModule(nn.Module):
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
results.append(intermediate_tensor_or_tensor_tuple)
# 3. Concatenate split restuls to obtain final outputs
if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=self.split_dim)
elif isinstance(results[0], tuple):