From 22d3a82651a2d9436ccd254b696d2c7cd23f3ff0 Mon Sep 17 00:00:00 2001 From: Soof Golan <83900570+soof-golan@users.noreply.github.com> Date: Tue, 10 Dec 2024 20:07:26 +0200 Subject: [PATCH] Improve post-processing performance (#10170) * Use multiplication instead of division * Add fast path when denormalizing all or none of the images --- src/diffusers/image_processor.py | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 00d8588d5a..d6913f045a 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -236,7 +236,7 @@ class VaeImageProcessor(ConfigMixin): `np.ndarray` or `torch.Tensor`: The denormalized image array. """ - return (images / 2 + 0.5).clamp(0, 1) + return (images * 0.5 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: @@ -537,6 +537,26 @@ class VaeImageProcessor(ConfigMixin): return image + def _denormalize_conditionally( + self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None + ) -> torch.Tensor: + r""" + Denormalize a batch of images based on a condition list. + + Args: + images (`torch.Tensor`): + The input image tensor. + do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`): + A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the + value of `do_normalize` in the `VaeImageProcessor` config. + """ + if do_denormalize is None: + return self.denormalize(images) if self.config.do_normalize else images + + return torch.stack( + [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])] + ) + def get_default_height_width( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], @@ -752,12 +772,7 @@ class VaeImageProcessor(ConfigMixin): if output_type == "latent": return image - if do_denormalize is None: - do_denormalize = [self.config.do_normalize] * image.shape[0] - - image = torch.stack( - [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] - ) + image = self._denormalize_conditionally(image, do_denormalize) if output_type == "pt": return image @@ -966,12 +981,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" - if do_denormalize is None: - do_denormalize = [self.config.do_normalize] * image.shape[0] - - image = torch.stack( - [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] - ) + image = self._denormalize_conditionally(image, do_denormalize) image = self.pt_to_numpy(image)