diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 1823636e40..efeb553c19 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -1121,7 +1121,6 @@ class FreeNoiseTransformerBlock(nn.Module):
         if self._chunk_size is not None:
             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            norm_hidden_states = self.norm3(hidden_states)
             ff_output = self.ff(norm_hidden_states)
 
         hidden_states = ff_output + hidden_states
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 17fbdc526a..9f9bc5a46e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2372,8 +2372,6 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        # TODO: figure out a better way to do this
-        # hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0)
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index cf5beb8723..7fcd34a9aa 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -34,50 +34,50 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class ChunkedInferenceModule(nn.Module):
+class SplitInferenceModule(nn.Module):
     def __init__(
         self,
         module: nn.Module,
-        chunk_size: int = 1,
-        chunk_dim: int = 0,
-        input_kwargs_to_chunk: List[str] = ["hidden_states"],
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
     ) -> None:
         super().__init__()
 
         self.module = module
-        self.chunk_size = chunk_size
-        self.chunk_dim = chunk_dim
-        self.input_kwargs_to_chunk = set(input_kwargs_to_chunk)
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
 
     def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
-        r"""Forward method of `ChunkedInferenceModule`.
+        r"""Forward method of `SplitInferenceModule`.
 
-        All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be
-        chunked that are specified in `inputs_to_chunk` when initializing the module.
+        All inputs that should be split should be passed as keyword arguments. Only the keyword arguments specified
+        in `input_kwargs_to_split` when initializing the module will be split.
         """
-        chunked_inputs = {}
+        split_inputs = {}
 
         for key in list(kwargs.keys()):
-            if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
                 continue
-            chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim)
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
             kwargs.pop(key)
 
         results = []
-        for chunked_input in zip(*chunked_inputs.values()):
-            inputs = dict(zip(chunked_inputs.keys(), chunked_input))
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
             inputs.update(kwargs)
 
             intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
             results.append(intermediate_tensor_or_tensor_tuple)
 
         if isinstance(results[0], torch.Tensor):
-            return torch.cat(results, dim=self.chunk_dim)
+            return torch.cat(results, dim=self.split_dim)
         elif isinstance(results[0], tuple):
-            return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)])
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
         else:
             raise ValueError(
-                "In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+                "In order to use the SplitInferenceModule, the underlying `module` must return either a torch.Tensor or a tuple of torch.Tensors."
             )
0, ["input_tensor", "temb"]) - def _enable_chunked_inference_samplers_( - self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int ) -> None: for i in range(len(samplers)): - samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"]) + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) - def enable_free_noise_chunked_inference( - self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16 - ) -> None: + def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: if getattr(block, "motion_modules", None) is not None: - self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size) + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) if getattr(block, "attentions", None) is not None: - self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size) + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) if getattr(block, "resnets", None) is not None: - self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size) + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) if getattr(block, "downsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size) + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) if getattr(block, "upsamplers", None) is not None: - self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size) + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) @property def free_noise_enabled(self):