mirror of https://github.com/huggingface/diffusers.git

chunk -> split

Author: Aryan
Date: 2024-09-05 07:52:28 +02:00
Parent: 2cef5c72cb
Commit: fb96059eb7

3 changed files with 39 additions and 44 deletions


@@ -1121,7 +1121,6 @@ class FreeNoiseTransformerBlock(nn.Module):
         if self._chunk_size is not None:
             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            norm_hidden_states = self.norm3(hidden_states)
             ff_output = self.ff(norm_hidden_states)

         hidden_states = ff_output + hidden_states
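The deleted line is redundant rather than behavioral: the if-branch already consumes a `norm_hidden_states` computed before this hunk, so re-applying `self.norm3` in the else-branch recomputed the same tensor. For context, the `self._chunk_size` branch runs the feed-forward over fixed-size slices to bound peak activation memory. A minimal sketch of that pattern (hedged; the upstream `_chunked_feed_forward` helper in diffusers also raises a more descriptive error when the length is not divisible by the chunk size):

    import torch
    import torch.nn as nn

    def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
        # Run `ff` over `chunk_size`-sized slices along `chunk_dim` so only one
        # slice's feed-forward activations are resident at a time, then reassemble.
        if hidden_states.shape[chunk_dim] % chunk_size != 0:
            raise ValueError("sequence length must be divisible by chunk_size along chunk_dim")
        num_chunks = hidden_states.shape[chunk_dim] // chunk_size
        return torch.cat(
            [ff(hs_slice) for hs_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )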


@@ -2372,8 +2372,6 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        # TODO: figure out a better way to do this
-        # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0)
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
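The two deleted lines were a leftover experiment, not functional code: because the output projection `attn.to_out[0]` is an `nn.Linear` that acts independently on each batch row, projecting 4-row slices and concatenating is numerically the same as projecting the whole tensor, just with a lower peak memory footprint. A small self-contained check of that equivalence (shapes are illustrative, not from the diff):

    import torch
    import torch.nn as nn

    to_out = nn.Linear(64, 64)
    hidden_states = torch.randn(16, 77, 64)  # (batch, seq_len, dim)

    full = to_out(hidden_states)
    split = torch.cat([to_out(x) for x in hidden_states.split(4, dim=0)], dim=0)
    assert torch.allclose(full, split, atol=1e-6)  # identical result either way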


@@ -34,50 +34,50 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class ChunkedInferenceModule(nn.Module):
+class SplitInferenceModule(nn.Module):
     def __init__(
         self,
         module: nn.Module,
-        chunk_size: int = 1,
-        chunk_dim: int = 0,
-        input_kwargs_to_chunk: List[str] = ["hidden_states"],
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
     ) -> None:
         super().__init__()

         self.module = module
-        self.chunk_size = chunk_size
-        self.chunk_dim = chunk_dim
-        self.input_kwargs_to_chunk = set(input_kwargs_to_chunk)
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)

     def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
-        r"""Forward method of `ChunkedInferenceModule`.
+        r"""Forward method of `SplitInferenceModule`.

-        All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be
-        chunked that are specified in `inputs_to_chunk` when initializing the module.
+        All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
+        split that are specified in `inputs_to_split` when initializing the module.
         """
-        chunked_inputs = {}
+        split_inputs = {}

         for key in list(kwargs.keys()):
-            if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
                 continue
-            chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim)
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
             kwargs.pop(key)

         results = []
-        for chunked_input in zip(*chunked_inputs.values()):
-            inputs = dict(zip(chunked_inputs.keys(), chunked_input))
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
             inputs.update(kwargs)

             intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
             results.append(intermediate_tensor_or_tensor_tuple)

         if isinstance(results[0], torch.Tensor):
-            return torch.cat(results, dim=self.chunk_dim)
+            return torch.cat(results, dim=self.split_dim)
         elif isinstance(results[0], tuple):
-            return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)])
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
         else:
             raise ValueError(
-                "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
             )
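A toy demonstration of the wrapper's contract (the `nn.Linear` and shapes are illustrative; `SplitInferenceModule` is assumed importable from the patched file): calling the wrapped module with a listed keyword argument runs one forward pass per slice and concatenates the results along `split_dim`, so the output matches an unsplit call.

    import torch
    import torch.nn as nn

    proj = nn.Linear(8, 8)
    # Split the "input" kwarg into slices of 2 along dim 0 -> 8 forward passes.
    wrapped = SplitInferenceModule(proj, split_size=2, split_dim=0, input_kwargs_to_split=["input"])

    x = torch.randn(16, 8)
    out = wrapped(input=x)
    assert torch.allclose(out, proj(x), atol=1e-6)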
@@ -603,55 +603,53 @@ class AnimateDiffFreeNoiseMixin:
         for block in blocks:
             self._disable_free_noise_in_block(block)

-    def _enable_chunked_inference_motion_modules_(
-        self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int
+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
     ) -> None:
         for motion_module in motion_modules:
-            motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"])
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])

             for i in range(len(motion_module.transformer_blocks)):
-                motion_module.transformer_blocks[i] = ChunkedInferenceModule(
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
                     motion_module.transformer_blocks[i],
-                    spatial_chunk_size,
+                    spatial_split_size,
                     0,
                     ["hidden_states", "encoder_hidden_states"],
                 )

-            motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"])
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])

-    def _enable_chunked_inference_attentions_(
-        self, attentions: List[Transformer2DModel], temporal_chunk_size: int
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
     ) -> None:
         for i in range(len(attentions)):
-            attentions[i] = ChunkedInferenceModule(
-                attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"]
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
             )

-    def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None:
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
         for i in range(len(resnets)):
-            resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"])
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])

-    def _enable_chunked_inference_samplers_(
-        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
     ) -> None:
         for i in range(len(samplers)):
-            samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"])
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])

-    def enable_free_noise_chunked_inference(
-        self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16
-    ) -> None:
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
         blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             if getattr(block, "motion_modules", None) is not None:
-                self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size)
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
             if getattr(block, "attentions", None) is not None:
-                self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size)
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
             if getattr(block, "resnets", None) is not None:
-                self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size)
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
             if getattr(block, "downsamplers", None) is not None:
-                self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size)
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
             if getattr(block, "upsamplers", None) is not None:
-                self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size)
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)

     @property
     def free_noise_enabled(self):
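After this commit the public entry point is `enable_free_noise_split_inference`. A hedged end-to-end sketch of how it would be called on an AnimateDiff pipeline once FreeNoise is enabled (the checkpoint names are illustrative stand-ins from the AnimateDiff docs, not part of this commit):

    import torch
    from diffusers import AnimateDiffPipeline, MotionAdapter

    adapter = MotionAdapter.from_pretrained(
        "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
    )
    pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    # FreeNoise long-video sampling, then cap peak memory by splitting batches:
    # spatial splits of 256 inside motion modules, temporal splits of 16 elsewhere,
    # mirroring the defaults in the method above.
    pipe.enable_free_noise(context_length=16, context_stride=4)
    pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)

    frames = pipe("a panda surfing on a wave", num_frames=64).frames[0]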