mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
chunk -> split
This commit is contained in:
@@ -1121,7 +1121,6 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
@@ -2372,8 +2372,6 @@ class AttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
# TODO: figure out a better way to do this
|
||||
# hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -34,50 +34,50 @@ from ..utils.torch_utils import randn_tensor
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChunkedInferenceModule(nn.Module):
|
||||
class SplitInferenceModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
chunk_size: int = 1,
|
||||
chunk_dim: int = 0,
|
||||
input_kwargs_to_chunk: List[str] = ["hidden_states"],
|
||||
split_size: int = 1,
|
||||
split_dim: int = 0,
|
||||
input_kwargs_to_split: List[str] = ["hidden_states"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.module = module
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_dim = chunk_dim
|
||||
self.input_kwargs_to_chunk = set(input_kwargs_to_chunk)
|
||||
self.split_size = split_size
|
||||
self.split_dim = split_dim
|
||||
self.input_kwargs_to_split = set(input_kwargs_to_split)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
r"""Forward method of `ChunkedInferenceModule`.
|
||||
r"""Forward method of `SplitInferenceModule`.
|
||||
|
||||
All inputs that should be chunked should be passed as keyword arguments. Only those keyword arguments will be
|
||||
chunked that are specified in `input_kwargs_to_chunk` when initializing the module.
|
||||
All inputs that should be split should be passed as keyword arguments. Only those keyword arguments will be
|
||||
split that are specified in `input_kwargs_to_split` when initializing the module.
|
||||
"""
|
||||
chunked_inputs = {}
|
||||
split_inputs = {}
|
||||
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]):
|
||||
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
|
||||
continue
|
||||
chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim)
|
||||
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
|
||||
kwargs.pop(key)
|
||||
|
||||
results = []
|
||||
for chunked_input in zip(*chunked_inputs.values()):
|
||||
inputs = dict(zip(chunked_inputs.keys(), chunked_input))
|
||||
for split_input in zip(*split_inputs.values()):
|
||||
inputs = dict(zip(split_inputs.keys(), split_input))
|
||||
inputs.update(kwargs)
|
||||
|
||||
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
|
||||
results.append(intermediate_tensor_or_tensor_tuple)
|
||||
|
||||
if isinstance(results[0], torch.Tensor):
|
||||
return torch.cat(results, dim=self.chunk_dim)
|
||||
return torch.cat(results, dim=self.split_dim)
|
||||
elif isinstance(results[0], tuple):
|
||||
return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)])
|
||||
return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
|
||||
else:
|
||||
raise ValueError(
|
||||
"In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
|
||||
"In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
|
||||
)
|
||||
|
||||
|
||||
@@ -603,55 +603,53 @@ class AnimateDiffFreeNoiseMixin:
|
||||
for block in blocks:
|
||||
self._disable_free_noise_in_block(block)
|
||||
|
||||
def _enable_chunked_inference_motion_modules_(
|
||||
self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int
|
||||
def _enable_split_inference_motion_modules_(
|
||||
self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
|
||||
) -> None:
|
||||
for motion_module in motion_modules:
|
||||
motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"])
|
||||
motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
|
||||
|
||||
for i in range(len(motion_module.transformer_blocks)):
|
||||
motion_module.transformer_blocks[i] = ChunkedInferenceModule(
|
||||
motion_module.transformer_blocks[i] = SplitInferenceModule(
|
||||
motion_module.transformer_blocks[i],
|
||||
spatial_chunk_size,
|
||||
spatial_split_size,
|
||||
0,
|
||||
["hidden_states", "encoder_hidden_states"],
|
||||
)
|
||||
|
||||
motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"])
|
||||
motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
|
||||
|
||||
def _enable_chunked_inference_attentions_(
|
||||
self, attentions: List[Transformer2DModel], temporal_chunk_size: int
|
||||
def _enable_split_inference_attentions_(
|
||||
self, attentions: List[Transformer2DModel], temporal_split_size: int
|
||||
) -> None:
|
||||
for i in range(len(attentions)):
|
||||
attentions[i] = ChunkedInferenceModule(
|
||||
attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"]
|
||||
attentions[i] = SplitInferenceModule(
|
||||
attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
|
||||
)
|
||||
|
||||
def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None:
|
||||
def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
|
||||
for i in range(len(resnets)):
|
||||
resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"])
|
||||
resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
|
||||
|
||||
def _enable_chunked_inference_samplers_(
|
||||
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int
|
||||
def _enable_split_inference_samplers_(
|
||||
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
|
||||
) -> None:
|
||||
for i in range(len(samplers)):
|
||||
samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"])
|
||||
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
|
||||
|
||||
def enable_free_noise_chunked_inference(
|
||||
self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16
|
||||
) -> None:
|
||||
def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
if getattr(block, "motion_modules", None) is not None:
|
||||
self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size)
|
||||
self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
|
||||
if getattr(block, "attentions", None) is not None:
|
||||
self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size)
|
||||
self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
|
||||
if getattr(block, "resnets", None) is not None:
|
||||
self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size)
|
||||
self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
|
||||
if getattr(block, "downsamplers", None) is not None:
|
||||
self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size)
|
||||
self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
|
||||
if getattr(block, "upsamplers", None) is not None:
|
||||
self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size)
|
||||
self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
|
||||
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
|
||||
Reference in New Issue
Block a user