mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add explanation of SplitInferenceModule
This commit is contained in:
@@ -50,19 +50,65 @@ class SplitInferenceModule(nn.Module):
|
||||
self.input_kwargs_to_split = set(input_kwargs_to_split)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
r"""Forward method of `SplitInferenceModule`.
|
||||
r"""Forward method for the `SplitInferenceModule`.
|
||||
|
||||
All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
|
||||
split that are specified in `inputs_to_split` when initializing the module.
|
||||
This method processes the input by splitting specified keyword arguments along a given dimension, running the
|
||||
underlying module on each split, and then concatenating the results. The splitting is controlled by the
|
||||
`split_size` and `split_dim` parameters specified during initialization.
|
||||
|
||||
Args:
|
||||
*args (`Any`):
|
||||
Positional arguments that are passed directly to the `module` without modification.
|
||||
**kwargs (`Dict[str, torch.Tensor]`):
|
||||
Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
|
||||
entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
|
||||
arguments are passed unchanged.
|
||||
|
||||
Returns:
|
||||
`Union[torch.Tensor, Tuple[torch.Tensor]]`:
|
||||
The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
|
||||
without it.
|
||||
- If the underlying module returns a single tensor, the result will be a single concatenated tensor
|
||||
along the same `split_dim` after processing all splits.
|
||||
- If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
|
||||
along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
|
||||
|
||||
Workflow:
|
||||
1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
|
||||
`torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
|
||||
2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
|
||||
that were passed.
|
||||
3. The output tensors from each split are concatenated back together along `split_dim` before returning.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
>>> model = nn.Linear(1000, 1000)
|
||||
>>> split_module = SplitInferenceModule(
|
||||
... model, split_size=2, split_dim=0, input_kwargs_to_split=["input_data"]
|
||||
... )
|
||||
|
||||
>>> input_tensor = torch.randn(42, 1000)
|
||||
>>> # Will split the tensor into 21 slices of shape [2, 1000].
|
||||
>>> output = split_module(input_data=input_tensor)
|
||||
```
|
||||
|
||||
This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
|
||||
them into smaller chunks, processing each chunk separately, and then reassembling the results.
|
||||
|
||||
It is also possible to nest `SplitInferenceModule` across different split dimensions.
|
||||
"""
|
||||
split_inputs = {}
|
||||
|
||||
# 1. Split inputs that were specified during initialization and also present in passed kwargs
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
|
||||
continue
|
||||
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
|
||||
kwargs.pop(key)
|
||||
|
||||
# 2. Invoke forward pass across each split
|
||||
results = []
|
||||
for split_input in zip(*split_inputs.values()):
|
||||
inputs = dict(zip(split_inputs.keys(), split_input))
|
||||
@@ -71,6 +117,7 @@ class SplitInferenceModule(nn.Module):
|
||||
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
|
||||
results.append(intermediate_tensor_or_tensor_tuple)
|
||||
|
||||
# 3. Concatenate split restuls to obtain final outputs
|
||||
if isinstance(results[0], torch.Tensor):
|
||||
return torch.cat(results, dim=self.split_dim)
|
||||
elif isinstance(results[0], tuple):
|
||||
|
||||
Reference in New Issue
Block a user