mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Improve post-processing performance (#10170)
* Use multiplication instead of division * Add fast path when denormalizing all or none of the images
This commit is contained in:
@@ -236,7 +236,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
`np.ndarray` or `torch.Tensor`:
|
||||
The denormalized image array.
|
||||
"""
|
||||
return (images / 2 + 0.5).clamp(0, 1)
|
||||
return (images * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
@@ -537,6 +537,26 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return image
|
||||
|
||||
def _denormalize_conditionally(
|
||||
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Denormalize a batch of images based on a condition list.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The input image tensor.
|
||||
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
|
||||
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
|
||||
value of `do_normalize` in the `VaeImageProcessor` config.
|
||||
"""
|
||||
if do_denormalize is None:
|
||||
return self.denormalize(images) if self.config.do_normalize else images
|
||||
|
||||
return torch.stack(
|
||||
[self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
|
||||
)
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
@@ -752,12 +772,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if output_type == "latent":
|
||||
return image
|
||||
|
||||
if do_denormalize is None:
|
||||
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
||||
|
||||
image = torch.stack(
|
||||
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
||||
)
|
||||
image = self._denormalize_conditionally(image, do_denormalize)
|
||||
|
||||
if output_type == "pt":
|
||||
return image
|
||||
@@ -966,12 +981,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
||||
output_type = "np"
|
||||
|
||||
if do_denormalize is None:
|
||||
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
||||
|
||||
image = torch.stack(
|
||||
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
||||
)
|
||||
image = self._denormalize_conditionally(image, do_denormalize)
|
||||
|
||||
image = self.pt_to_numpy(image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user