From 661a0b389d2cfb50be104c6a3ea2ab810d98c174 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Sep 2024 11:45:55 +0200 Subject: [PATCH] add explanation of SplitInferenceModule --- src/diffusers/pipelines/free_noise_utils.py | 53 +++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index b894a44192..fabd98062b 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -50,19 +50,65 @@ class SplitInferenceModule(nn.Module): self.input_kwargs_to_split = set(input_kwargs_to_split) def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - r"""Forward method of `SplitInferenceModule`. + r"""Forward method for the `SplitInferenceModule`. - All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be - split that are specified in `inputs_to_split` when initializing the module. + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. + **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. 
+ - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + + >>> model = torch.nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule( + ... model, split_size=2, split_dim=0, input_kwargs_to_split=["input"] + ... ) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + It is also possible to nest `SplitInferenceModule` across different split dimensions. """ split_inputs = {} + # 1. Split inputs that were specified during initialization and also present in passed kwargs for key in list(kwargs.keys()): if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): continue split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) kwargs.pop(key) + # 2.
Invoke forward pass across each split results = [] for split_input in zip(*split_inputs.values()): inputs = dict(zip(split_inputs.keys(), split_input)) @@ -71,6 +117,7 @@ class SplitInferenceModule(nn.Module): intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) results.append(intermediate_tensor_or_tensor_tuple) + # 3. Concatenate split results to obtain final outputs if isinstance(results[0], torch.Tensor): return torch.cat(results, dim=self.split_dim) elif isinstance(results[0], tuple):