1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make style

This commit is contained in:
Patrick von Platen
2023-08-30 14:13:14 +02:00
parent af3854d6ad
commit d1e20be664

View File

@@ -1,14 +1,14 @@
from typing import Optional, Union, List, Callable, Dict, Any
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
debug_save = False
@torch.no_grad()
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
mask: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
):
r"""
The call function to the pipeline for generation.
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
# mean of the latent distribution
init_latents = [
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean for i in range(batch_size)
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
# compute the previous noisy sample x_t -> x_t-1
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
def _make_latent_mask(self, latents, mask):
if mask is not None:
latent_mask = list()
latent_mask = []
if not isinstance(mask, list):
tmp_mask = [mask]
else:
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
m = m / 255.0
m = self.image_processor.numpy_to_pil(m)[0]
if m.mode != "L":
m = m.convert('L')
m = m.convert("L")
resized = self.image_processor.resize(m, l_height, l_width)
if self.debug_save:
resized.save("latent_mask.png")