mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
make style
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
from typing import Optional, Union, List, Callable, Dict, Any
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
|
||||
debug_save = False
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
mask: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
|
||||
# mean of the latent distribution
|
||||
init_latents = [
|
||||
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean for i in range(batch_size)
|
||||
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
|
||||
for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
|
||||
noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
|
||||
def _make_latent_mask(self, latents, mask):
|
||||
if mask is not None:
|
||||
latent_mask = list()
|
||||
latent_mask = []
|
||||
if not isinstance(mask, list):
|
||||
tmp_mask = [mask]
|
||||
else:
|
||||
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
m = m / 255.0
|
||||
m = self.image_processor.numpy_to_pil(m)[0]
|
||||
if m.mode != "L":
|
||||
m = m.convert('L')
|
||||
m = m.convert("L")
|
||||
resized = self.image_processor.resize(m, l_height, l_width)
|
||||
if self.debug_save:
|
||||
resized.save("latent_mask.png")
|
||||
|
||||
Reference in New Issue
Block a user