From d1e20be664dd8774e49d1a9d54fd71ec7cd5863c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 30 Aug 2023 14:13:14 +0200 Subject: [PATCH] make style --- .../masked_stable_diffusion_img2img.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/community/masked_stable_diffusion_img2img.py b/examples/community/masked_stable_diffusion_img2img.py index b8182b0b0c..95dff4fd18 100644 --- a/examples/community/masked_stable_diffusion_img2img.py +++ b/examples/community/masked_stable_diffusion_img2img.py @@ -1,14 +1,14 @@ -from typing import Optional, Union, List, Callable, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch + from diffusers import StableDiffusionImg2ImgPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): - debug_save = False @torch.no_grad() @@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, mask: Union[ - torch.FloatTensor, - PIL.Image.Image, - np.ndarray, - List[torch.FloatTensor], - List[PIL.Image.Image], - List[np.ndarray], - ] = None, + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, ): r""" The call function to the pipeline for generation. @@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): # mean of the latent distribution init_latents = [ - self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean for i in range(batch_size) + self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) @@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask) noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask) - # compute the previous noisy sample x_t -> x_t-1 + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided @@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): def _make_latent_mask(self, latents, mask): if mask is not None: - latent_mask = list() + latent_mask = [] if not isinstance(mask, list): tmp_mask = [mask] else: @@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): m = m / 255.0 m = self.image_processor.numpy_to_pil(m)[0] if m.mode != "L": - m = m.convert('L') + m = m.convert("L") resized = self.image_processor.resize(m, l_height, l_width) if self.debug_save: resized.save("latent_mask.png")