mirror of https://github.com/huggingface/diffusers.git
[Flax] Add Flax inpainting impl (#1966)
* [Flax] Add Flax inpainting impl
* fixed copies, add README.md
* fixed README.md
* add test
* format
* update README.md
README.md
@@ -284,6 +284,53 @@ output = pipeline(
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```

Diffusers also provides a text-guided inpainting pipeline with Flax/JAX:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import PIL
import requests
from io import BytesIO


from diffusers import FlaxStableDiffusionInpaintPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained("xvjiarui/stable-diffusion-2-inpainting")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
init_image = num_samples * [init_image]
mask_image = num_samples * [mask_image]
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
processed_masked_images = shard(processed_masked_images)
processed_masks = shard(processed_masks)

images = pipeline(prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
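The mask follows the usual Stable Diffusion inpainting convention: white pixels mark the region to repaint, black pixels are preserved. As a minimal sketch of what `prepare_inputs` does with the mask (the 0.5 threshold comes from the pipeline source added in this commit; `jnp.where` stands in here for the `.at[].set()` updates used there):

```python
import jax.numpy as jnp

mask = jnp.array([[0.2, 0.7], [0.9, 0.1]])     # grayscale mask, values in [0, 1]
binary_mask = jnp.where(mask < 0.5, 0.0, 1.0)  # 1 = inpaint, 0 = keep
image = jnp.ones((2, 2))                       # stand-in for the image pixels
masked_image = image * (binary_mask < 0.5)     # only the preserved pixels survive
```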

### Image-to-Image text-guided generation with Stable Diffusion

The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.

@@ -182,4 +182,8 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
-    from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
+    from .pipelines import (
+        FlaxStableDiffusionImg2ImgPipeline,
+        FlaxStableDiffusionInpaintPipeline,
+        FlaxStableDiffusionPipeline,
+    )

@@ -108,4 +108,8 @@ try:
 except OptionalDependencyNotAvailable:
     from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
-    from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
+    from .stable_diffusion import (
+        FlaxStableDiffusionImg2ImgPipeline,
+        FlaxStableDiffusionInpaintPipeline,
+        FlaxStableDiffusionPipeline,
+    )

@@ -99,4 +99,5 @@ if is_transformers_available() and is_flax_available():
     from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
     from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
+    from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

@@ -0,0 +1,524 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import numpy as np

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False


class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
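        # With the standard Stable Diffusion VAE (four entries in block_out_channels),
        # vae_scale_factor works out to 2**3 = 8: a 512x512 image maps to 64x64 latents.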

    def prepare_inputs(
        self,
        prompt: Union[str, List[str]],
        image: Union[Image.Image, List[Image.Image]],
        mask: Union[Image.Image, List[Image.Image]],
    ):
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        if not isinstance(mask, (Image.Image, list)):
            raise ValueError(f"mask has to be of type `PIL.Image.Image` or list but is {type(mask)}")

        if isinstance(mask, Image.Image):
            mask = [mask]

        processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
        processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
        # JAX arrays are immutable, so the numpy-style in-place updates
        # (processed_masks[processed_masks < 0.5] = 0, etc.) become functional .at[].set() calls
        processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
        processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)
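
        # The masked image keeps the original pixels where the binarized mask is 0 and
        # zeroes out the region to be inpainted (mask == 1).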
        processed_masked_images = processed_images * (processed_masks < 0.5)

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids, processed_masked_images, processed_masks

    def _get_has_nsfw_concepts(self, features, params):
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        mask: jnp.array,
        masked_image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([uncond_embeddings, text_embeddings])

        latents_shape = (
            batch_size,
            self.vae.config.latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        prng_seed, mask_prng_seed = jax.random.split(prng_seed)

        masked_image_latent_dist = self.vae.apply(
            {"params": params["vae"]}, masked_image, method=self.vae.encode
        ).latent_dist
        masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
        masked_image_latents = 0.18215 * masked_image_latents
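        # 0.18215 is the latent scaling factor used by Stable Diffusion's VAE; the
        # decoder divides by the same constant further below.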
        del mask_prng_seed

        mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")

        # 8. Check that sizes of mask, masked image and latents match
        num_channels_latents = self.vae.config.latent_channels
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
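        # The inpainting UNet expects latents, mask, and masked-image latents stacked
        # along the channel axis (4 + 1 + 4 = 9 for the stable-diffusion-2-inpainting checkpoint).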
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        def loop_body(step, args):
            latents, mask, masked_image_latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)
            mask_input = jnp.concatenate([mask] * 2)
            masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
            # concat latents, mask, masked_image_latents in the channel dimension
            latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, mask, masked_image_latents, scheduler_state
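
        # loop_body returns mask and masked_image_latents unchanged: jax.lax.fori_loop
        # threads the full loop state through its carry, so they have to be part of it.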
        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, mask, masked_image_latents, scheduler_state = loop_body(
                    i, (latents, mask, masked_image_latents, scheduler_state)
                )
        else:
            latents, _, _, _ = jax.lax.fori_loop(
                0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state)
            )

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    def __call__(
        self,
        prompt_ids: jnp.array,
        mask: jnp.array,
        masked_image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: jnp.array = None,
        neg_prompt_ids: jnp.array = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The tokenized prompt or prompts to guide the image generation, as produced by `prepare_inputs`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `prng_seed`.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
                of a plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
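        # When jit=True the caller is expected to have replicated `params` and sharded the
        # array inputs across devices (see the README example); `_p_generate` below pmaps
        # `_generate` over the leading device axis.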
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic")
        mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest")

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            num_devices, batch_size = images.shape[:2]

            images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
            images = np.asarray(images)

            # block images
            if any(has_nsfw_concept):
                for i, is_nsfw in enumerate(has_nsfw_concept):
                    if is_nsfw:
                        images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(num_devices, batch_size, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 6, 7, 8),
)
def _p_generate(
    pipe,
    prompt_ids,
    mask,
    masked_image,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    return pipe._generate(
        prompt_ids,
        mask,
        masked_image,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    return pipe._get_has_nsfw_concepts(features, params)


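# Undo `shard`: merge the leading (num_devices, batch_per_device) axes back into a single
# batch axis, e.g. (8, 1, 512, 512, 3) -> (8, 512, 512, 3).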
def unshard(x: jnp.ndarray):
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)


def preprocess_image(image, dtype):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = jnp.array(image).astype(dtype) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return 2.0 * image - 1.0


def preprocess_mask(mask, dtype):
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w, h))
    mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0
    mask = jnp.expand_dims(mask, axis=(0, 1))

    return mask
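For reference, a quick shape check for the two preprocessing helpers above (a standalone sketch; the pipeline normally calls them through `prepare_inputs`, and the import path is inferred from the relative imports added in this commit):

```python
import jax.numpy as jnp
from PIL import Image

from diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion_inpaint import (
    preprocess_image,
    preprocess_mask,
)

image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512), color=255)  # white = region to repaint

processed_image = preprocess_image(image, jnp.float32)  # NCHW, scaled to [-1, 1]
processed_mask = preprocess_mask(mask, jnp.float32)     # single channel, in [0, 1]

assert processed_image.shape == (1, 3, 512, 512)
assert processed_mask.shape == (1, 1, 512, 512)
```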
@@ -19,6 +19,21 @@ class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
         requires_backends(cls, ["flax", "transformers"])


+class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject):
+    _backends = ["flax", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+
 class FlaxStableDiffusionPipeline(metaclass=DummyObject):
     _backends = ["flax", "transformers"]
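When the Flax and Transformers backends are missing, importing the pipeline name from `diffusers` resolves to this dummy class, which defers the failure from import time to use time. A hypothetical sketch of the resulting behavior:

```python
# In an environment without flax installed (hypothetical):
from diffusers import FlaxStableDiffusionInpaintPipeline  # import succeeds

# Any attempt to actually use the class raises an informative error via
# requires_backends, telling the user which backends to install.
FlaxStableDiffusionInpaintPipeline.from_pretrained("xvjiarui/stable-diffusion-2-inpainting")
```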
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

from diffusers import FlaxStableDiffusionInpaintPipeline
from diffusers.utils import is_flax_available, load_image, slow
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard


@slow
@require_flax
class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def test_stable_diffusion_inpaint_pipeline(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-inpaint/init_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
        )

        model_id = "xvjiarui/stable-diffusion-2-inpainting"
        pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 50

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        init_image = num_samples * [init_image]
        mask_image = num_samples * [mask_image]
        prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, jax.device_count())
        prompt_ids = shard(prompt_ids)
        processed_masked_images = shard(processed_masked_images)
        processed_masks = shard(processed_masks)

        output = pipeline(
            prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
        )

        images = output.images.reshape(num_samples, 512, 512, 3)

        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        expected_slice = jnp.array(
            [0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
        )
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2