Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
add model (#3230)
* add
* clean
* up
* clean up more
* fix more tests
* Improve docs further
* improve
* more fixes docs
* Improve docs more
* Update src/diffusers/models/unet_2d_condition.py
* fix
* up
* update doc links
* make fix-copies
* add safety checker and watermarker to stage 3 doc page code snippets
* speed optimizations docs
* memory optimization docs
* make style
* add watermarking snippets to doc string examples
* make style
* use pt_to_pil helper functions in doc strings
* skip mps tests
* Improve safety
* make style
* new logic
* fix
* fix bad onnx design
* make new stable diffusion upscale pipeline model arguments optional
* define has_nsfw_concept when non-pil output type
* lowercase linked to notebook name

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
Committed by Daniel Gu
Parent: 711119ae7a
Commit: 416f31adf8
@@ -152,6 +152,8 @@
    title: DDPM
  - local: api/pipelines/dit
    title: DiT
  - local: api/pipelines/if
    title: IF
  - local: api/pipelines/latent_diffusion
    title: Latent Diffusion
  - local: api/pipelines/paint_by_example
523
docs/source/en/api/pipelines/if.mdx
Normal file
@@ -0,0 +1,523 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# IF

## Overview

DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
The model is modular, composed of a frozen text encoder and three cascaded pixel diffusion modules:
- Stage 1: a base model that generates a 64x64 px image based on the text prompt,
- Stage 2: a 64x64 px => 256x256 px super-resolution model, and
- Stage 3: a 256x256 px => 1024x1024 px super-resolution model
Stages 1 and 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings,
which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.
Stage 3 is [Stability's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).
The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.
Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.

## Usage

Before you can use IF, you need to accept its usage conditions. To do so:
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in
2. Accept the license on the model card of [DeepFloyd/IF-I-IF-v1.0](https://huggingface.co/DeepFloyd/IF-I-IF-v1.0) and [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
3. Make sure to log in locally. Install `huggingface_hub`:

```sh
pip install huggingface_hub --upgrade
```

Then run the login function in a Python shell:

```py
from huggingface_hub import login

login()
```

and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
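
If you are working in a non-interactive environment, `login` also accepts the token as an argument. A minimal sketch; the token string below is a placeholder you must replace with your own:

```py
from huggingface_hub import login

login(token="hf_xxx")  # placeholder access token; never commit a real token
```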

Next we install `diffusers` and dependencies:

```sh
pip install diffusers accelerate transformers safetensors
```

The following sections give more detailed examples of how to use IF. Specifically:

- [Text-to-Image Generation](#text-to-image-generation)
- [Image-to-Image Generation](#text-guided-image-to-image-generation)
- [Inpainting](#text-guided-inpainting-generation)
- [Reusing model weights](#converting-between-different-pipelines)
- [Speed optimization](#optimizing-for-speed)
- [Memory optimization](#optimizing-for-memory)

**Available checkpoints**
- *Stage-1*
  - [DeepFloyd/IF-I-IF-v1.0](https://huggingface.co/DeepFloyd/IF-I-IF-v1.0)
  - [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)
  - [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)

- *Stage-2*
  - [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
  - [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)

- *Stage-3*
  - [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)

**Demo**
[](https://huggingface.co/spaces/DeepFloyd/IF)

**Google Colab**
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)

### Text-to-Image Generation

By default diffusers makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings)
to run the whole IF pipeline with as little as 14 GB of VRAM.

```python
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {
    "feature_extractor": stage_1.feature_extractor,
    "safety_checker": stage_1.safety_checker,
    "watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
generator = torch.manual_seed(1)

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

# stage 1
image = stage_1(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, noise_level=100, generator=generator).images
image[0].save("./if_stage_III.png")
```
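
Note that passing `output_type="pt"` keeps the intermediate images as `torch` tensors, so they can be fed directly into the next stage, while the `pt_to_pil` helper converts a batch of such tensors back to PIL images whenever you want to inspect or save them.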

### Text Guided Image-to-Image Generation

The same IF model weights can be used for text-guided image-to-image translation or image variation.
In this case just make sure to load the weights using the [`IFImg2ImgPipeline`] and [`IFImg2ImgSuperResolutionPipeline`] pipelines.

**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).

```python
from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
from diffusers.utils import pt_to_pil

import torch

from PIL import Image
import requests
from io import BytesIO

# download image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")
original_image = original_image.resize((768, 512))

# stage 1
stage_1 = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {
    "feature_extractor": stage_1.feature_extractor,
    "safety_checker": stage_1.safety_checker,
    "watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()

prompt = "A fantasy landscape in style minecraft"
generator = torch.manual_seed(1)

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

# stage 1
image = stage_1(
    image=original_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image,
    original_image=original_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")
```

### Text Guided Inpainting Generation

The same IF model weights can be used for text-guided inpainting.
In this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.

**Note**: You can also directly move the weights of the text-to-image pipelines to the inpainting pipelines
without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).

```python
from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

from PIL import Image
import requests
from io import BytesIO

# download image
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")

# download mask
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
response = requests.get(url)
mask_image = Image.open(BytesIO(response.content))

# stage 1
stage_1 = IFInpaintingPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {
    "feature_extractor": stage_1.feature_extractor,
    "safety_checker": stage_1.safety_checker,
    "watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()

prompt = "blue sunglasses"
generator = torch.manual_seed(1)

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

# stage 1
image = stage_1(
    image=original_image,
    mask_image=mask_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image,
    original_image=original_image,
    mask_image=mask_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=generator,
    output_type="pt",
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")
```

### Converting between different pipelines

In addition to being loaded with `from_pretrained`, pipelines can also be loaded directly from each other.

```python
from diffusers import IFPipeline, IFSuperResolutionPipeline

pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0")
pipe_2 = IFSuperResolutionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0")


from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline

pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)


from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline

pipe_1 = IFInpaintingPipeline(**pipe_1.components)
pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
```
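
Under the hood, `pipe.components` returns a dictionary of all modules registered with the pipeline (tokenizer, text encoder, UNet, scheduler, and the safety modules), so constructing a new pipeline from it reuses the weights already in memory instead of loading them a second time.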

### Optimizing for speed

The simplest optimization to run IF faster is to move all model components to the GPU.

```py
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.to("cuda")
```

You can also run the diffusion process for a shorter number of timesteps.

This can either be done with the `num_inference_steps` argument:

```py
pipe("<prompt>", num_inference_steps=30)
```

Or with the `timesteps` argument:

```py
from diffusers.pipelines.deepfloyd_if import fast27_timesteps

pipe("<prompt>", timesteps=fast27_timesteps)
```
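
Besides `fast27_timesteps`, the `deepfloyd_if` module also exports several other predefined schedules (`smart27_timesteps`, `smart50_timesteps`, `smart100_timesteps`, `smart185_timesteps`, and the `super27_timesteps`/`super40_timesteps`/`super100_timesteps` variants for the super-resolution stages), which can be passed the same way:

```py
from diffusers.pipelines.deepfloyd_if import smart50_timesteps

# trade a little quality for speed by running only the 50 "smart" timesteps
pipe("<prompt>", timesteps=smart50_timesteps)
```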

When doing image variation or inpainting, you can also decrease the number of timesteps
with the `strength` argument. The `strength` argument is the amount of noise to add to
the input image, which also determines how many steps to run in the denoising process.
A smaller number will vary the image less but run faster.

```py
pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(image=image, prompt="<prompt>", strength=0.3).images
```

You can also use [`torch.compile`](../../optimization/torch2.0). Note that we have not exhaustively tested `torch.compile`
with IF and it might not give expected results.

```py
import torch

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.to("cuda")

pipe.text_encoder = torch.compile(pipe.text_encoder)
pipe.unet = torch.compile(pipe.unet)
```
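
As a further, equally untested variation, PyTorch 2.0's `torch.compile` accepts a `mode` argument; a sketch using the `"reduce-overhead"` mode:

```py
# trades longer compile time for lower per-call overhead; behavior with IF is untested
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
```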

### Optimizing for memory

When optimizing for GPU memory, we can use the standard diffusers CPU offloading APIs.

Either model-based CPU offloading,

```py
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
```

or the more aggressive layer-based CPU offloading.

```py
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload()
```

Additionally, T5 can be loaded in 8-bit precision:

```py
from transformers import T5EncoderModel

text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-IF-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
)

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-IF-v1.0",
    text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder
    unet=None,
    device_map="auto",
)

prompt_embeds, negative_embeds = pipe.encode_prompt("<prompt>")
```
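
Loading in 8-bit precision via `load_in_8bit` relies on the `bitsandbytes` library, so make sure it is installed (`pip install bitsandbytes`).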

For machines with constrained CPU RAM, such as the free tier of Google Colab, where we can't load all
model components to CPU at once, we can manually load the pipeline with only
the text encoder or UNet when the respective model components are needed.

```py
from diffusers import DiffusionPipeline, IFPipeline, IFSuperResolutionPipeline
import torch
import gc
from transformers import T5EncoderModel
from diffusers.utils import pt_to_pil

text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-IF-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
)

# text to image

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-IF-v1.0",
    text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder
    unet=None,
    device_map="auto",
)

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

# Remove the pipeline so we can re-load the pipeline with the unet
del text_encoder
del pipe
gc.collect()
torch.cuda.empty_cache()

pipe = IFPipeline.from_pretrained(
    "DeepFloyd/IF-I-IF-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
)

generator = torch.Generator().manual_seed(0)
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    output_type="pt",
    generator=generator,
).images

pt_to_pil(image)[0].save("./if_stage_I.png")

# Remove the pipeline so we can load the super-resolution pipeline
del pipe
gc.collect()
torch.cuda.empty_cache()

# First super resolution

pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
)

generator = torch.Generator().manual_seed(0)
image = pipe(
    image=image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    output_type="pt",
    generator=generator,
).images

pt_to_pil(image)[0].save("./if_stage_II.png")
```
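
The same delete / `gc.collect()` / `torch.cuda.empty_cache()` pattern extends to the Stage 3 upscaler. A minimal sketch, continuing from the variables above:

```py
# free the stage-2 pipeline before loading the Stage 3 upscaler
del pipe
gc.collect()
torch.cuda.empty_cache()

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe(prompt=prompt, image=image, noise_level=100, generator=generator).images
image[0].save("./if_stage_III.png")
```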


## Available Pipelines:

| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |
| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py) | *Text-to-Image Generation* | - |
| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - |
| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - |
| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Inpainting* | - |
| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Inpainting* | - |

## IFPipeline
[[autodoc]] IFPipeline
    - all
    - __call__

## IFSuperResolutionPipeline
[[autodoc]] IFSuperResolutionPipeline
    - all
    - __call__

## IFImg2ImgPipeline
[[autodoc]] IFImg2ImgPipeline
    - all
    - __call__

## IFImg2ImgSuperResolutionPipeline
[[autodoc]] IFImg2ImgSuperResolutionPipeline
    - all
    - __call__

## IFInpaintingPipeline
[[autodoc]] IFInpaintingPipeline
    - all
    - __call__

## IFInpaintingSuperResolutionPipeline
[[autodoc]] IFInpaintingSuperResolutionPipeline
    - all
    - __call__

@@ -51,6 +51,9 @@ available a colab notebook to directly try them out.
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [if](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
| [if_img2img](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
| [if_inpainting](./if) | [**IF**](https://github.com/deep-floyd/IF) | Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |

@@ -58,6 +58,9 @@ The library has three main components:
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [if](./api/pipelines/if) | [IF](https://github.com/deep-floyd/IF) | Image Generation |
| [if_img2img](./api/pipelines/if) | [IF](https://github.com/deep-floyd/IF) | Image-to-Image Generation |
| [if_inpainting](./api/pipelines/if) | [IF](https://github.com/deep-floyd/IF) | Inpainting |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |

1257
scripts/convert_if.py
Normal file
File diff suppressed because it is too large
@@ -115,6 +115,12 @@ else:
        AudioLDMPipeline,
        CycleDiffusionPipeline,
        ImageTextPipelineOutput,
        IFImg2ImgPipeline,
        IFImg2ImgSuperResolutionPipeline,
        IFInpaintingPipeline,
        IFInpaintingSuperResolutionPipeline,
        IFPipeline,
        IFSuperResolutionPipeline,
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        SemanticStableDiffusionPipeline,

@@ -109,6 +109,7 @@ class ConfigMixin:
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
@@ -550,6 +551,9 @@ class ConfigMixin:
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files"
        config_dict.pop("_ignore_files", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):

@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning


class TextTimeEmbedding(nn.Module):
    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, class_token_length, dim_per_head)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, length+class_token_length, dim_per_head)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token

@@ -15,6 +15,7 @@
# limitations under the License.

import inspect
+import itertools
import os
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -60,7 +61,8 @@ if is_safetensors_available():

def get_parameter_device(parameter: torch.nn.Module):
    try:
-        return next(parameter.parameters()).device
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

@@ -75,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):

def get_parameter_dtype(parameter: torch.nn.Module):
    try:
-        return next(parameter.parameters()).dtype
+        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+        return next(parameters_and_buffers).dtype
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        class_embed_type (`str`, *optional*, defaults to None):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to None):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use on the time embeddings, applied only once before they are passed to
            the rest of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

        # time
        if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        elif time_embedding_act_fn == "swish":
@@ -684,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        else:
            emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
            emb = emb + aug_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
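
To illustrate the new `addition_embed_type` option in isolation, here is a hypothetical toy configuration (the block sizes, attention dimensions, and head counts below are arbitrary; the IF checkpoints ship their own config values):

```py
from diffusers import UNet2DConditionModel

# toy model: addition_embed_type="text" wires a TextTimeEmbedding that attention-pools
# the encoder hidden states and adds the result to the time embedding
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    addition_embed_type="text",
    addition_embed_type_num_heads=4,
)
```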
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
else:
    from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
    from .audioldm import AudioLDMPipeline
    from .deepfloyd_if import (
        IFImg2ImgPipeline,
        IFImg2ImgSuperResolutionPipeline,
        IFInpaintingPipeline,
        IFInpaintingSuperResolutionPipeline,
        IFPipeline,
        IFSuperResolutionPipeline,
    )
    from .latent_diffusion import LDMTextToImagePipeline
    from .paint_by_example import PaintByExamplePipeline
    from .semantic_stable_diffusion import SemanticStableDiffusionPipeline

54
src/diffusers/pipelines/deepfloyd_if/__init__.py
Normal file
@@ -0,0 +1,54 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL

from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from .timesteps import (
    fast27_timesteps,
    smart27_timesteps,
    smart50_timesteps,
    smart100_timesteps,
    smart185_timesteps,
    super27_timesteps,
    super40_timesteps,
    super100_timesteps,
)


@dataclass
class IFPipelineOutput(BaseOutput):
    """
    Output class for the DeepFloyd IF pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content or a watermark. `None` if safety checking could not be performed.
        watermark_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
            checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_detected: Optional[List[bool]]
    watermark_detected: Optional[List[bool]]


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .pipeline_if import IFPipeline
    from .pipeline_if_img2img import IFImg2ImgPipeline
    from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline
    from .pipeline_if_inpainting import IFInpaintingPipeline
    from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline
    from .pipeline_if_superresolution import IFSuperResolutionPipeline
    from .safety_checker import IFSafetyChecker
    from .watermark import IFWatermarker
854
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
Normal file
@@ -0,0 +1,854 @@
import html
import inspect
import re
import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
    BACKENDS_MAPPING,
    is_accelerate_available,
    is_accelerate_version,
    is_bs4_available,
    is_ftfy_available,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch

        >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

        >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")

        >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
        ... )
        >>> super_res_1_pipe.enable_model_cpu_offload()

        >>> image = super_res_1_pipe(
        ...     image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
        ... ).images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_II.png")

        >>> safety_modules = {
        ...     "feature_extractor": pipe.feature_extractor,
        ...     "safety_checker": pipe.safety_checker,
        ...     "watermarker": pipe.watermarker,
        ... }
        >>> super_res_2_pipe = DiffusionPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
        ... )
        >>> super_res_2_pipe.enable_model_cpu_offload()

        >>> image = super_res_2_pipe(
        ...     prompt=prompt,
        ...     image=image,
        ... ).images
        >>> image[0].save("./if_stage_III.png")
        ```
"""


class IFPipeline(DiffusionPipeline):
    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel

    unet: UNet2DConditionModel
    scheduler: DDPMScheduler

    feature_extractor: Optional[CLIPImageProcessor]
    safety_checker: Optional[IFSafetyChecker]

    watermarker: Optional[IFWatermarker]

    bad_punct_regex = re.compile(
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.text_encoder,
            self.unet,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None

        if self.text_encoder is not None:
            _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)

            # Accelerate will move the next model to the device _before_ calling the offload hook of the
            # previous model. This will cause both models to be present on the device at the same time.
            # IF uses T5 for its text encoder which is really large. We can manually call the offload
            # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
            # the GPU.
            self.text_encoder_offload_hook = hook

        _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)

        # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
        self.unet_offload_hook = hook

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    def remove_all_hooks(self):
        if is_accelerate_available():
            from accelerate.hooks import remove_hook_from_module
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        for model in [self.text_encoder, self.unet, self.safety_checker]:
            if model is not None:
                remove_hook_from_module(model, recurse=True)

        self.unet_offload_hook = None
        self.text_encoder_offload_hook = None
        self.final_offload_hook = None

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def encode_prompt(
        self,
        prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        device=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if device is None:
            device = self._execution_device

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
        max_length = 77

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because the text encoder can only handle"
                    f" sequences up to {max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.unet is not None:
            dtype = self.unet.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None

        return prompt_embeds, negative_prompt_embeds

def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, nsfw_detected, watermark_detected = self.safety_checker(
|
||||
images=image,
|
||||
clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
|
||||
)
|
||||
else:
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
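The same `inspect.signature` idiom works for any callable whose keyword support varies across implementations; a standalone sketch (function names here are illustrative, not part of diffusers):

```py
>>> import inspect

>>> def step_with_eta(sample, eta=0.0):
...     return sample

>>> def step_plain(sample):
...     return sample

>>> def build_kwargs(step_fn, eta):
...     # only forward `eta` if the target signature actually accepts it
...     return {"eta": eta} if "eta" in inspect.signature(step_fn).parameters else {}

>>> build_kwargs(step_with_eta, 0.5)
{'eta': 0.5}
>>> build_kwargs(step_plain, 0.5)
{}
```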

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
        return intermediate_images
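`init_noise_sigma` lets the same sampling loop serve schedulers with different noise conventions: it is `1.0` for variance-preserving schedulers such as the DDPM scheduler this pipeline registers, and larger for sigma-parameterized schedulers. A quick check, assuming a default-configured `DDPMScheduler`:

```py
>>> from diffusers import DDPMScheduler

>>> DDPMScheduler().init_noise_sigma
1.0
```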

    def _text_preprocessing(self, text, clean_caption=False):
        if clean_caption and not is_bs4_available():
            logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]

    def _clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # все виды тире / all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # кавычки к одному стандарту / normalize quotation marks
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        caption = caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()
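A usage sketch of the preprocessing entry point above (requires `beautifulsoup4` and `ftfy`; with `clean_caption=False` the method only lowercases and strips). The exact output depends on the full regex cascade, so none is shown here:

```py
>>> # `pipe` is an instantiated IFPipeline; `_text_preprocessing` is what `encode_prompt` calls internally
>>> pipe._text_preprocessing(
...     ["SALE!!! visit https://shop.example.com &amp; get FREE SHIPPING @bestdeals"],
...     clean_caption=True,
... )
```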

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 100,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        clean_caption: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images, and the second element is a list
            of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
            or watermarked content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        height = height or self.unet.config.sample_size
        width = width or self.unet.config.sample_size

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clean_caption=clean_caption,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        # 5. Prepare intermediate images
        intermediate_images = self.prepare_intermediate_images(
            batch_size * num_images_per_prompt,
            self.unet.config.in_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # HACK: see comment in `enable_model_cpu_offload`
        if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
            self.text_encoder_offload_hook.offload()

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                model_input = (
                    torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
                )
                model_input = self.scheduler.scale_model_input(model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
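                    # IF's UNet predicts twice as many channels as the sample has: the first half is the noise
                    # estimate, the second half a learned variance. Guidance is applied to the noise half only,
                    # and the variance predicted by the text-conditioned branch is re-attached afterwards so the
                    # scheduler can still use it.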
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                    noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
                    noise_pred, t, intermediate_images, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, intermediate_images)

        image = intermediate_images

        if output_type == "pil":
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)

            # 11. Apply watermark
            if self.watermarker is not None:
                self.watermarker.apply_watermark(image, self.unet.config.sample_size)
        elif output_type == "pt":
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()
        else:
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, nsfw_detected, watermark_detected)

        return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
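The method above is the complete stage-I sampling loop; continuing from the `pipe` and embeddings in the `encode_prompt` sketch earlier, a minimal end-to-end use looks like:

```py
>>> from diffusers.utils import pt_to_pil

>>> # `output_type="pt"` keeps the 64x64 result as a tensor, ready for the stage-II upscaler
>>> image = pipe(
...     prompt_embeds=prompt_embeds,
...     negative_prompt_embeds=negative_embeds,
...     output_type="pt",
... ).images
>>> pt_to_pil(image)[0].save("./if_stage_I.png")
```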
979
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
Normal file
@@ -0,0 +1,979 @@
import html
import inspect
import re
import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
    BACKENDS_MAPPING,
    PIL_INTERPOLATION,
    is_accelerate_available,
    is_accelerate_version,
    is_bs4_available,
    is_ftfy_available,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy


def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    w, h = images.size

    coef = w / h

    w, h = img_size, img_size

    if coef >= 1:
        w = int(round(img_size / 8 * coef) * 8)
    else:
        h = int(round(img_size / 8 / coef) * 8)

    images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)

    return images
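`resize` keeps the aspect ratio while snapping the longer side to a multiple of 8, which the UNet's downsampling blocks require. A worked example with `img_size=64`: a 768x512 input has `coef = 1.5`, so `w = round(64 / 8 * 1.5) * 8 = 96` and the output is 96x64. Using the `resize` helper defined above:

```py
>>> from PIL import Image

>>> im = Image.new("RGB", (768, 512))
>>> resize(im, 64).size
(96, 64)
```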


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from io import BytesIO

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        >>> response = requests.get(url)
        >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> original_image = original_image.resize((768, 512))

        >>> pipe = IFImg2ImgPipeline.from_pretrained(
        ...     "DeepFloyd/IF-I-IF-v1.0",
        ...     variant="fp16",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A fantasy landscape in style minecraft"
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

        >>> image = pipe(
        ...     image=original_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ...     output_type="pt",
        ... ).images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")

        >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0",
        ...     text_encoder=None,
        ...     variant="fp16",
        ...     torch_dtype=torch.float16,
        ... )
        >>> super_res_1_pipe.enable_model_cpu_offload()

        >>> image = super_res_1_pipe(
        ...     image=image,
        ...     original_image=original_image,
        ...     prompt_embeds=prompt_embeds,
        ...     negative_prompt_embeds=negative_embeds,
        ... ).images
        >>> image[0].save("./if_stage_II.png")
        ```
"""


class IFImg2ImgPipeline(DiffusionPipeline):
    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel

    unet: UNet2DConditionModel
    scheduler: DDPMScheduler

    feature_extractor: Optional[CLIPImageProcessor]
    safety_checker: Optional[IFSafetyChecker]

    watermarker: Optional[IFWatermarker]

    bad_punct_regex = re.compile(
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.text_encoder,
            self.unet,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None

        if self.text_encoder is not None:
            _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)

            # Accelerate will move the next model to the device _before_ calling the offload hook of the
            # previous model. This will cause both models to be present on the device at the same time.
            # IF uses T5 for its text encoder which is really large. We can manually call the offload
            # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
            # the GPU.
            self.text_encoder_offload_hook = hook

        _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)

        # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
        self.unet_offload_hook = hook

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
    def remove_all_hooks(self):
        if is_accelerate_available():
            from accelerate.hooks import remove_hook_from_module
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        for model in [self.text_encoder, self.unet, self.safety_checker]:
            if model is not None:
                remove_hook_from_module(model, recurse=True)

        self.unet_offload_hook = None
        self.text_encoder_offload_hook = None
        self.final_offload_hook = None

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        device=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if device is None:
            device = self._execution_device

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
        max_length = 77

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because the T5 text encoder can only handle"
                    f" sequences up to {max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.unet is not None:
            dtype = self.unet.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            nsfw_detected = None
            watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()

        return image, nsfw_detected, watermark_detected

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        batch_size,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if isinstance(image, list):
            check_image_type = image[0]
        else:
            check_image_type = image

        if (
            not isinstance(check_image_type, torch.Tensor)
            and not isinstance(check_image_type, PIL.Image.Image)
            and not isinstance(check_image_type, np.ndarray)
        ):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(check_image_type)}"
            )

        if isinstance(image, list):
            image_batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            image_batch_size = image.shape[0]
        elif isinstance(image, PIL.Image.Image):
            image_batch_size = 1
        elif isinstance(image, np.ndarray):
            image_batch_size = image.shape[0]
        else:
            assert False

        if batch_size != image_batch_size:
            raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
    def _text_preprocessing(self, text, clean_caption=False):
        if clean_caption and not is_bs4_available():
            logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
    def _clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # все виды тире / all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # кавычки к одному стандарту / normalize quotation marks
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        caption = caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()

    def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
        if not isinstance(image, list):
            image = [image]

        def numpy_to_pt(images):
            if images.ndim == 3:
                images = images[..., None]

            images = torch.from_numpy(images.transpose(0, 3, 1, 2))
            return images

        if isinstance(image[0], PIL.Image.Image):
            new_image = []

            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = resize(image_, self.unet.sample_size)
                image_ = np.array(image_)
                image_ = image_.astype(np.float32)
                image_ = image_ / 127.5 - 1
                new_image.append(image_)

            image = new_image

            image = np.stack(image, axis=0)  # to np
            image = numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = numpy_to_pt(image)

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

        return image
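The normalization in the PIL branch above, `image_ / 127.5 - 1`, maps 8-bit pixel values onto the `[-1, 1]` range the UNet operates in; a quick check of the arithmetic:

```py
>>> import numpy as np

>>> pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
>>> pixels / 127.5 - 1
array([-1.,  0.,  1.], dtype=float32)
```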

    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start
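`strength` trades fidelity to the input image for freedom to change it by skipping the earliest (noisiest) part of the schedule. A worked example with the defaults of `__call__` below (`num_inference_steps=80`, `strength=0.7`): `init_timestep = min(int(80 * 0.7), 80) = 56`, so `t_start = 24` and denoising runs over the last 56 scheduler timesteps only:

```py
>>> num_inference_steps, strength = 80, 0.7

>>> init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
>>> t_start = max(num_inference_steps - init_timestep, 0)
>>> init_timestep, t_start
(56, 24)
```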

    def prepare_intermediate_images(
        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
    ):
        _, channels, height, width = image.shape

        batch_size = batch_size * num_images_per_prompt

        shape = (batch_size, channels, height, width)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        image = image.repeat_interleave(num_images_per_prompt, dim=0)
        image = self.scheduler.add_noise(image, noise, timestep)

        return image
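Unlike the text-to-image variant, which starts from pure noise, this method forward-diffuses the user's image to the first retained timestep via `scheduler.add_noise` (DDPM convention: `noisy = sqrt(alpha_prod_t) * image + sqrt(1 - alpha_prod_t) * noise`). A minimal sketch of that call in isolation, assuming a default-configured `DDPMScheduler`:

```py
>>> import torch
>>> from diffusers import DDPMScheduler

>>> scheduler = DDPMScheduler()
>>> scheduler.set_timesteps(80)

>>> image = torch.randn(1, 3, 64, 64)
>>> noise = torch.randn(1, 3, 64, 64)
>>> t = scheduler.timesteps[0:1]  # the noisiest retained timestep
>>> scheduler.add_noise(image, noise, t).shape
torch.Size([1, 3, 64, 64])
```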
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
|
||||
] = None,
|
||||
strength: float = 0.7,
|
||||
num_inference_steps: int = 80,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 10.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
                to be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
                returning a tuple, the first element is a list with the generated images, and the second element is a
                list of `bool`s denoting whether the corresponding generated image likely represents
                "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        self.check_inputs(
            prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clean_caption=clean_caption,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        dtype = prompt_embeds.dtype

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

        # 5. Prepare intermediate images
        image = self.preprocess_image(image)
        image = image.to(device=device, dtype=dtype)

        noise_timestep = timesteps[0:1]
        noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

        intermediate_images = self.prepare_intermediate_images(
            image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # HACK: see comment in `enable_model_cpu_offload`
        if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
            self.text_encoder_offload_hook.offload()

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                model_input = (
                    torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
                )
                model_input = self.scheduler.scale_model_input(model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                    noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
                    noise_pred, t, intermediate_images, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, intermediate_images)

        image = intermediate_images

        if output_type == "pil":
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)

            # 11. Apply watermark
            if self.watermarker is not None:
                self.watermarker.apply_watermark(image, self.unet.config.sample_size)
        elif output_type == "pt":
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()
        else:
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, nsfw_detected, watermark_detected)

        return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
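As in the docstring example earlier in this diff, `output_type="pt"` skips PIL conversion, the safety checker and the watermark, which keeps intermediate results as tensors for the next cascade stage. A minimal sketch, assuming this `__call__` belongs to the IF image-to-image pipeline (the class statement sits above this excerpt; the input image is a placeholder):

```python
import torch
from PIL import Image
from diffusers import IFImg2ImgPipeline  # assumed class name for the pipeline above

pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

original_image = Image.new("RGB", (768, 768))  # placeholder input
prompt_embeds, negative_embeds = pipe.encode_prompt("a fantasy landscape")

image = pipe(
    image=original_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    output_type="pt",  # stay in tensor form; safety checking and watermarking happen in a later stage
).images
```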
File diff suppressed because it is too large

1098  src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py  (new file)
File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,947 @@
import html
import inspect
import re
import urllib.parse as ul
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
    BACKENDS_MAPPING,
    is_accelerate_available,
    is_accelerate_version,
    is_bs4_available,
    is_ftfy_available,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker


if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch

        >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

        >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")

        >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
        ... )
        >>> super_res_1_pipe.enable_model_cpu_offload()

        >>> image = super_res_1_pipe(
        ...     image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
        ... ).images
        >>> image[0].save("./if_stage_II.png")
        ```
"""


class IFSuperResolutionPipeline(DiffusionPipeline):
    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel

    unet: UNet2DConditionModel
    scheduler: DDPMScheduler
    image_noising_scheduler: DDPMScheduler

    feature_extractor: Optional[CLIPImageProcessor]
    safety_checker: Optional[IFSafetyChecker]

    watermarker: Optional[IFWatermarker]

    bad_punct_regex = re.compile(
        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        image_noising_scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        if unet.config.in_channels != 6:
            logger.warn(
                f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
            )

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            image_noising_scheduler=image_noising_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.text_encoder,
            self.unet,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None

        if self.text_encoder is not None:
            _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)

            # Accelerate will move the next model to the device _before_ calling the offload hook of the
            # previous model. This will cause both models to be present on the device at the same time.
            # IF uses T5 for its text encoder which is really large. We can manually call the offload
            # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
            # the GPU.
            self.text_encoder_offload_hook = hook

        _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)

        # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
        self.unet_offload_hook = hook

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook
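A short usage sketch contrasting the two offload modes (model ID as in the example docstring above; pick one mode per pipeline):

```python
import torch
from diffusers import IFSuperResolutionPipeline

pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)

# whole-model hops between CPU and GPU: faster, moderate memory savings
pipe.enable_model_cpu_offload()

# or: per-submodule hops via accelerate, slowest but smallest GPU footprint
# pipe.enable_sequential_cpu_offload()
```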

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
    def remove_all_hooks(self):
        if is_accelerate_available():
            from accelerate.hooks import remove_hook_from_module
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        for model in [self.text_encoder, self.unet, self.safety_checker]:
            if model is not None:
                remove_hook_from_module(model, recurse=True)

        self.unet_offload_hook = None
        self.text_encoder_offload_hook = None
        self.final_offload_hook = None

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
    def _text_preprocessing(self, text, clean_caption=False):
        if clean_caption and not is_bs4_available():
            logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warn("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
    def _clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # все виды тире / all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # кавычки к одному стандарту / normalize all quotes to one standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        caption = caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()
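A few of the rules above, exercised on a synthetic caption to show the intent (sample string invented for illustration):

```python
import re

caption = "check http://example.com @someuser photo_0123456.jpg #12345 this_is_my_cute_cat"
caption = re.sub(r"@[\w\d]+\b", "", caption)  # strip @-mentions
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)  # strip filenames
caption = re.sub(r"#\d{5,}\b", "", caption)  # strip long "#12345.." ids
print(re.sub(r"\s+", " ", caption).strip())
# -> "check http://example.com this_is_my_cute_cat"
```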

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        device=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if device is None:
            device = self._execution_device

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
        max_length = 77

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because the text encoder can only handle"
                    f" sequences up to {max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.unet is not None:
            dtype = self.unet.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()

        return image, nsfw_detected, watermark_detected

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
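`prepare_extra_step_kwargs` filters by signature so the same pipeline works with schedulers that do and do not accept `eta` or `generator`. The same pattern in isolation (the helper name is ours):

```python
import inspect

from diffusers import DDIMScheduler, DDPMScheduler


def step_accepts(scheduler, name: str) -> bool:
    # same check as prepare_extra_step_kwargs: look for `name` in scheduler.step's signature
    return name in set(inspect.signature(scheduler.step).parameters.keys())


print(step_accepts(DDIMScheduler(), "eta"))  # True: DDIM consumes eta
print(step_accepts(DDPMScheduler(), "eta"))  # False: DDPM ignores it, so it is never passed
```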

    def check_inputs(
        self,
        prompt,
        image,
        batch_size,
        noise_level,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
            raise ValueError(
                f"`noise_level`: {noise_level} must be a valid timestep in `self.image_noising_scheduler`, i.e. in"
                f" [0, {self.image_noising_scheduler.config.num_train_timesteps})"
            )

        if isinstance(image, list):
            check_image_type = image[0]
        else:
            check_image_type = image

        if (
            not isinstance(check_image_type, torch.Tensor)
            and not isinstance(check_image_type, PIL.Image.Image)
            and not isinstance(check_image_type, np.ndarray)
        ):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or a list of those,"
                f" but is {type(check_image_type)}"
            )

        if isinstance(image, list):
            image_batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            image_batch_size = image.shape[0]
        elif isinstance(image, PIL.Image.Image):
            image_batch_size = 1
        elif isinstance(image, np.ndarray):
            image_batch_size = image.shape[0]
        else:
            assert False

        if batch_size != image_batch_size:
            raise ValueError(
                f"image batch size: {image_batch_size} must be the same as prompt batch size {batch_size}"
            )

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images
    def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
        return intermediate_images
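Because the initial noise comes from `randn_tensor` with the caller's generator, sampling is reproducible across runs. A small determinism sketch (shape illustrative; for the DDPM scheduler used here `init_noise_sigma` is 1.0, so the scaling is a no-op):

```python
import torch
from diffusers.utils import randn_tensor

shape = (1, 3, 256, 256)
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)
a = randn_tensor(shape, generator=g1) * 1.0  # times init_noise_sigma
b = randn_tensor(shape, generator=g2)
assert torch.equal(a, b)  # same seed -> identical intermediate images
```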

    def preprocess_image(self, image, num_images_per_prompt, device):
        if not isinstance(image, torch.Tensor) and not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]

            image = np.stack(image, axis=0)  # to np
            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        elif isinstance(image[0], np.ndarray):
            image = np.stack(image, axis=0)  # to np
            if image.ndim == 5:
                image = image[0]

            image = torch.from_numpy(image.transpose(0, 3, 1, 2))
        elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
            dims = image[0].ndim

            if dims == 3:
                image = torch.stack(image, dim=0)
            elif dims == 4:
                image = torch.concat(image, dim=0)
            else:
                raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")

        image = image.to(device=device, dtype=self.unet.dtype)

        image = image.repeat_interleave(num_images_per_prompt, dim=0)

        return image
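For intuition, the PIL branch above in isolation; a blank image stands in for a real input:

```python
import numpy as np
import torch
from PIL import Image

pil = Image.new("RGB", (64, 64))
arr = np.array(pil).astype(np.float32) / 255.0  # HWC, scaled to [0, 1]
batch = torch.from_numpy(np.stack([arr], axis=0).transpose(0, 3, 1, 2))  # NCHW: (1, 3, 64, 64)
```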

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 4.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        noise_level: int = 250,
        clean_caption: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
                The image to be upscaled.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equally spaced
                `num_inference_steps` timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            noise_level (`int`, *optional*, defaults to 250):
                The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
                to be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
                returning a tuple, the first element is a list with the generated images, and the second element is a
                list of `bool`s denoting whether the corresponding generated image likely represents
                "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        self.check_inputs(
            prompt,
            image,
            batch_size,
            noise_level,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters

        height = self.unet.config.sample_size
        width = self.unet.config.sample_size

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clean_caption=clean_caption,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        # 5. Prepare intermediate images
        num_channels = self.unet.config.in_channels // 2
        intermediate_images = self.prepare_intermediate_images(
            batch_size * num_images_per_prompt,
            num_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare upscaled image and noise level
        image = self.preprocess_image(image, num_images_per_prompt, device)
        upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

        noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
        noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
        upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

        if do_classifier_free_guidance:
            noise_level = torch.cat([noise_level] * 2)

        # HACK: see comment in `enable_model_cpu_offload`
        if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
            self.text_encoder_offload_hook.offload()

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                model_input = torch.cat([intermediate_images, upscaled], dim=1)

                model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
                model_input = self.scheduler.scale_model_input(model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    class_labels=noise_level,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                    noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
                    noise_pred, t, intermediate_images, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, intermediate_images)

        image = intermediate_images

        if output_type == "pil":
            # 9. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 10. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 11. Convert to PIL
            image = self.numpy_to_pil(image)

            # 12. Apply watermark
            if self.watermarker is not None:
                self.watermarker.apply_watermark(image, self.unet.config.sample_size)
        elif output_type == "pt":
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()
        else:
            # 9. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 10. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, nsfw_detected, watermark_detected)

        return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
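Step 7 above is the heart of the super-resolution conditioning: the input is bilinearly upscaled to the UNet's sample size and then deliberately re-noised, with the same `noise_level` fed to the UNet as `class_labels` so it knows how corrupted its conditioning is. The noising step in isolation, as a sketch (shapes illustrative, default DDPM config assumed):

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler
from diffusers.utils import randn_tensor

image_noising_scheduler = DDPMScheduler(num_train_timesteps=1000)

low_res = torch.randn(1, 3, 64, 64)  # stand-in for a stage-I sample in [-1, 1]
upscaled = F.interpolate(low_res, (256, 256), mode="bilinear", align_corners=True)

noise_level = torch.tensor([250] * upscaled.shape[0])
noise = randn_tensor(upscaled.shape)
upscaled = image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
```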
59  src/diffusers/pipelines/deepfloyd_if/safety_checker.py  (new file)
@@ -0,0 +1,59 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel

from ...utils import logging


logger = logging.get_logger(__name__)


class IFSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)

        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        image_embeds = self.vision_model(clip_input)[0]

        nsfw_detected = self.p_head(image_embeds)
        nsfw_detected = nsfw_detected.flatten()
        nsfw_detected = nsfw_detected > p_threshold
        nsfw_detected = nsfw_detected.tolist()

        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for idx, nsfw_detected_ in enumerate(nsfw_detected):
            if nsfw_detected_:
                images[idx] = np.zeros(images[idx].shape)

        watermark_detected = self.w_head(image_embeds)
        watermark_detected = watermark_detected.flatten()
        watermark_detected = watermark_detected > w_threshold
        watermark_detected = watermark_detected.tolist()

        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        for idx, watermark_detected_ in enumerate(watermark_detected):
            if watermark_detected_:
                images[idx] = np.zeros(images[idx].shape)

        return images, nsfw_detected, watermark_detected
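The two linear heads return raw scores that are thresholded per image, and both thresholds are call-time arguments, so a deployment can tighten them without retraining. A hedged sketch with a randomly initialized checker (real weights ship with the IF checkpoints; all inputs are placeholders):

```python
import numpy as np
import torch
from transformers import CLIPConfig

# IFSafetyChecker as defined above; a default CLIPConfig gives an untrained stand-in
checker = IFSafetyChecker(CLIPConfig())
clip_input = torch.randn(2, 3, 224, 224)  # feature-extractor pixel values (placeholder)
images = np.random.rand(2, 256, 256, 3).astype(np.float32)  # decoded images, NHWC in [0, 1]

images, nsfw, wm = checker(clip_input=clip_input, images=images, p_threshold=0.3, w_threshold=0.3)
print(nsfw, wm)  # one Python bool per image in each list
```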
579  src/diffusers/pipelines/deepfloyd_if/timesteps.py  (new file)
@@ -0,0 +1,579 @@
fast27_timesteps = [
    999, 800, 799, 600, 599, 500, 400, 399, 377, 355, 333, 311, 288, 266, 244, 222, 200, 199, 177, 155,
    133, 111, 88, 66, 44, 22, 0,
]

smart27_timesteps = [
    999, 976, 952, 928, 905, 882, 858, 857, 810, 762, 715, 714, 572, 429, 428, 286, 285, 238, 190, 143,
    142, 118, 95, 71, 47, 24, 0,
]

smart50_timesteps = [
    999, 988, 977, 966, 955, 944, 933, 922, 911, 900, 899, 879, 859, 840, 820, 800, 799, 766, 733, 700,
    699, 650, 600, 599, 500, 499, 400, 399, 350, 300, 299, 266, 233, 200, 199, 179, 159, 140, 120, 100,
    99, 88, 77, 66, 55, 44, 33, 22, 11, 0,
]

smart100_timesteps = [
    999, 995, 992, 989, 985, 981, 978, 975, 971, 967, 964, 961, 957, 956, 951, 947, 942, 937, 933, 928,
    923, 919, 914, 913, 908, 903, 897, 892, 887, 881, 876, 871, 870, 864, 858, 852, 846, 840, 834, 828,
    827, 820, 813, 806, 799, 792, 785, 784, 777, 770, 763, 756, 749, 742, 741, 733, 724, 716, 707, 699,
    698, 688, 677, 666, 656, 655, 645, 634, 623, 613, 612, 598, 584, 570, 569, 555, 541, 527, 526, 505,
    484, 483, 462, 440, 439, 396, 395, 352, 351, 308, 307, 264, 263, 220, 219, 176, 132, 88, 44, 0,
]

smart185_timesteps = [
    999, 997, 995, 992, 990, 988, 986, 984, 981, 979, 977, 975, 972, 970, 968, 966, 964, 961, 959, 957,
    956, 954, 951, 949, 946, 944, 941, 939, 936, 934, 931, 929, 926, 924, 921, 919, 916, 914, 913, 910,
    907, 905, 902, 899, 896, 893, 891, 888, 885, 882, 879, 877, 874, 871, 870, 867, 864, 861, 858, 855,
    852, 849, 846, 843, 840, 837, 834, 831, 828, 827, 824, 821, 817, 814, 811, 808, 804, 801, 798, 795,
    791, 788, 785, 784, 780, 777, 774, 770, 766, 763, 760, 756, 752, 749, 746, 742, 741, 737, 733, 730,
    726, 722, 718, 714, 710, 707, 703, 699, 698, 694, 690, 685, 681, 677, 673, 669, 664, 660, 656, 655,
    650, 646, 641, 636, 632, 627, 622, 618, 613, 612, 607, 602, 596, 591, 586, 580, 575, 570, 569, 563,
    557, 551, 545, 539, 533, 527, 526, 519, 512, 505, 498, 491, 484, 483, 474, 466, 457, 449, 440, 439,
    428, 418, 407, 396, 395, 381, 366, 352, 351, 330, 308, 307, 286, 264, 263, 242, 220, 219, 176, 175,
    132, 131, 88, 44, 0,
]

super27_timesteps = [
    999, 991, 982, 974, 966, 958, 950, 941, 933, 925, 916, 908, 900, 899, 874, 850, 825, 800, 799, 700,
    600, 500, 400, 300, 200, 100, 0,
]

super40_timesteps = [
    999, 992, 985, 978, 971, 964, 957, 949, 942, 935, 928, 921, 914, 907, 900, 899, 879, 859, 840, 820,
    800, 799, 766, 733, 700, 699, 650, 600, 599, 500, 499, 400, 399, 300, 299, 200, 199, 100, 99, 0,
]

super100_timesteps = [
    999, 996, 992, 989, 985, 982, 979, 975, 972, 968, 965, 961, 958, 955, 951, 948, 944, 941, 938, 934,
    931, 927, 924, 920, 917, 914, 910, 907, 903, 900, 899, 891, 884, 876, 869, 861, 853, 846, 838, 830,
    823, 815, 808, 800, 799, 788, 777, 766, 755, 744, 733, 722, 711, 700, 699, 688, 677, 666, 655, 644,
    633, 622, 611, 600, 599, 585, 571, 557, 542, 528, 514, 500, 499, 485, 471, 457, 442, 428, 414, 400,
    399, 379, 359, 340, 320, 300, 299, 279, 259, 240, 220, 200, 199, 166, 133, 100, 99, 66, 33, 0,
]
|
||||
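These hand-tuned schedules trade step count against quality. A minimal sketch of how one of them might be selected, assuming the `timesteps` keyword exposed by the IF pipelines and that this module is importable at the path of the file being added here:

```py
import torch

from diffusers import IFSuperResolutionPipeline
from diffusers.pipelines.deepfloyd_if.timesteps import super27_timesteps  # assumed import path

pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# stand-in for stage I output obtained with output_type="pt" (NCHW floats in [-1, 1])
stage_1_output = torch.randn(1, 3, 64, 64)

# run stage II on the precomputed 27-step schedule instead of a uniform one
image = pipe(
    prompt="anime turtle",
    image=stage_1_output,
    timesteps=super27_timesteps,
).images[0]
```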
46
src/diffusers/pipelines/deepfloyd_if/watermark.py
Normal file
@@ -0,0 +1,46 @@
from typing import List

import PIL
import torch
from PIL import Image

from ...configuration_utils import ConfigMixin
from ...models.modeling_utils import ModelMixin
from ...utils import PIL_INTERPOLATION


class IFWatermarker(ModelMixin, ConfigMixin):
    def __init__(self):
        super().__init__()

        self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
        self.watermark_image_as_pil = None

    def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
        # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287

        h = images[0].height
        w = images[0].width

        sample_size = sample_size or h

        coef = min(h / sample_size, w / sample_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        S1, S2 = 1024**2, img_w * img_h
        K = (S2 / S1) ** 0.5
        wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)

        if self.watermark_image_as_pil is None:
            watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy()
            watermark_image = Image.fromarray(watermark_image, mode="RGBA")
            self.watermark_image_as_pil = watermark_image

        wm_img = self.watermark_image_as_pil.resize(
            (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
        )

        for pil_img in images:
            pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])

        return images
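A minimal usage sketch of the watermarker above. Note the registered buffer is all zeros until real watermark weights are loaded from a checkpoint, so this only exercises the scaling and placement geometry:

```py
from PIL import Image

from diffusers.pipelines.deepfloyd_if import IFWatermarker

watermarker = IFWatermarker()
images = [Image.new("RGB", (256, 256))]

# pastes a K-scaled 62x62 RGBA watermark into the bottom-right corner, in place
watermarker.apply_watermark(images, sample_size=256)
images[0].save("watermarked.png")
```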
@@ -30,7 +30,6 @@ import PIL
import torch
from huggingface_hub import hf_hub_download, model_info, snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm

import diffusers
@@ -56,6 +55,7 @@ from ..utils import (
    is_torch_version,
    is_transformers_available,
    logging,
    numpy_to_pil,
)


@@ -623,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
        if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
            return False

        return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
        return hasattr(module, "_hf_hook") and not isinstance(
            module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
        )

    def module_is_offloaded(module):
        if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -653,7 +655,20 @@ class DiffusionPipeline(ConfigMixin):

        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
            module.to(torch_device, torch_dtype)
            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

            if is_loaded_in_8bit and torch_dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
                )

            if is_loaded_in_8bit and torch_device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
                )
            else:
                module.to(torch_device, torch_dtype)

            if (
                module.dtype == torch.float16
                and str(torch_device) in ["cpu"]
@@ -887,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):

        config_dict = cls.load_config(cached_folder)

        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)

        # 2. Define which model components should load variants
        # We retrieve the information by matching whether variant
        # model checkpoints exist in the subfolders
@@ -1204,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
        )

        config_dict = cls._dict_from_json_file(config_file)

        ignore_filenames = config_dict.pop("_ignore_files", [])

        # retrieve all folder_names that contain relevant files
        folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

        filenames = {sibling.rfilename for sibling in info.siblings}
        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

        # remove ignored filenames
        model_filenames = set(model_filenames) - set(ignore_filenames)
        variant_filenames = set(variant_filenames) - set(ignore_filenames)

        # if the whole pipeline is cached we don't have to ping the Hub
        if revision in DEPRECATED_REVISION_ARGS and version.parse(
            version.parse(__version__).base_version
@@ -1370,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images
        return numpy_to_pil(images)

    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
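A sketch of what the 8bit guard above means for callers, assuming a small test checkpoint (nothing in this commit loads 8bit weights itself; the repo name below is just an illustrative tiny pipeline):

```py
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

# regular modules are cast exactly as before
pipe.to(torch_dtype=torch.float16)

# a submodule loaded in 8bit (e.g. via transformers' load_in_8bit) would instead
# be skipped with the warning defined above and stay in 8bit precision
```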
@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
        scheduler: Any,
        max_noise_level: int = 350,
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            watermarker=None,
            max_noise_level=max_noise_level,
        )

    def __call__(
        self,
@@ -13,18 +13,19 @@
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -76,6 +77,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """
    _optional_components = ["watermarker", "safety_checker", "feature_extractor"]

    def __init__(
        self,
@@ -85,12 +87,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: Optional[Any] = None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        watermarker: Optional[Any] = None,
        max_noise_level: int = 350,
    ):
        super().__init__()

        if hasattr(vae, "config"):
            # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
        if hasattr(
            vae, "config"
        ):  # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
            is_vae_scaling_factor_set_to_0_08333 = (
                hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
            )
@@ -113,6 +119,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            safety_checker=safety_checker,
            watermarker=watermarker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(max_noise_level=max_noise_level)

@@ -178,6 +187,23 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            nsfw_detected = None
            watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()

        return image, nsfw_detected, watermark_detected

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
@@ -678,10 +704,19 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
            self.final_offload_hook.offload()

        # 11. Convert to PIL
        # has_nsfw_concept = False
        if output_type == "pil":
            image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)

            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)
            # 11. Apply watermark
            if self.watermarker is not None:
                image = self.watermarker.apply_watermark(image)
        else:
            has_nsfw_concept = None

        return ImagePipelineOutput(images=image)
        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
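With the optional safety checker and watermarker wired in, a standard call against the existing stabilityai upscaler checkpoint now returns the richer output type. A minimal sketch (the image URL is a known test asset; adjust paths as needed):

```py
import torch

from diffusers import StableDiffusionUpscalePipeline
from diffusers.utils import load_image

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

low_res = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
).resize((128, 128))

result = pipe(prompt="a white cat", image=low_res)
result.images[0].save("upscaled.png")

# the new output type also carries the safety checker verdict (None if no checker is loaded)
print(result.nsfw_content_detected)
```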
@@ -15,7 +15,7 @@ from ...models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
from ...models.dual_transformer_2d import DualTransformer2DModel
|
||||
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import logging
|
||||
@@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to None):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
time_embedding_type (`str`, *optional*, default to `positional`):
|
||||
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||
time_embedding_dim (`int`, *optional*, default to `None`):
|
||||
An optional override for the dimension of the projected time embedding.
|
||||
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
||||
Optional activation function to use on the time embeddings only one time before they as passed to the rest
|
||||
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
||||
@@ -246,12 +251,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_skip_time_act: bool = False,
|
||||
resnet_out_scale_factor: int = 1.0,
|
||||
time_embedding_type: str = "positional",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
time_embedding_act_fn: Optional[str] = None,
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
@@ -261,6 +268,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
class_embeddings_concat: bool = False,
|
||||
mid_block_only_cross_attention: Optional[bool] = None,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
addition_embed_type_num_heads=64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -311,7 +319,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = block_out_channels[0] * 2
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
@@ -319,7 +327,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
@@ -370,6 +378,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
|
||||
|
||||
if time_embedding_act_fn is None:
|
||||
self.time_embed_act = None
|
||||
elif time_embedding_act_fn == "swish":
|
||||
@@ -781,6 +801,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
emb = emb + aug_emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
|
||||
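The same flags exist on `UNet2DConditionModel` (this file is its flattened copy kept in sync via `make fix-copies`). A minimal sketch of enabling the text-time embedding, mirroring the dummy components used by the IF tests later in this commit:

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    layers_per_block=1,
    block_out_channels=[32, 64],
    down_block_types=["ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"],
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    up_block_types=["SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"],
    in_channels=3,
    out_channels=6,
    cross_attention_dim=32,
    encoder_hid_dim=32,
    attention_head_dim=8,
    addition_embed_type="text",  # attention-pools the prompt states and sums them with the timestep embedding
    addition_embed_type_num_heads=2,
)
```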
@@ -44,6 +44,7 @@ from .hub_utils import (
|
||||
http_user_agent,
|
||||
)
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
USE_JAX,
|
||||
@@ -53,7 +54,9 @@ from .import_utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_bs4_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_inflect_available,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
@@ -76,7 +79,7 @@ from .import_utils import (
|
||||
)
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .pil_utils import PIL_INTERPOLATION
|
||||
from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil
|
||||
from .torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,96 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class IFImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class IFInpaintingPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class IFPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class IFSuperResolutionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LDMTextToImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

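All of these placeholders follow the same `DummyObject` pattern: if torch or transformers is missing, constructing the class or calling `from_pretrained` raises an ImportError with install instructions instead of an opaque failure at import time. A minimal sketch of the pattern in isolation (the class name is illustrative):

```py
from diffusers.utils import DummyObject, requires_backends


class ExamplePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # raises ImportError with an install hint if torch/transformers are unavailable
        requires_backends(self, ["torch", "transformers"])
```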
@@ -271,6 +271,23 @@ except importlib_metadata.PackageNotFoundError:
    _compel_available = False


_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False


_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata under a different name
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported beautifulsoup4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False


def is_torch_available():
    return _torch_available

@@ -347,6 +364,14 @@ def is_compel_available():
    return _compel_available


def is_ftfy_available():
    return _ftfy_available


def is_bs4_available():
    return _bs4_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -437,8 +462,23 @@ COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""

# docstyle-ignore
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
        ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
@@ -454,6 +494,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
        ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
        ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
    ]
)
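With bs4 and ftfy registered in `BACKENDS_MAPPING`, code that needs them (the IF pipelines use both for prompt cleaning) can fail early with the install hints defined above. A minimal sketch; the function name is illustrative:

```py
from diffusers.utils.import_utils import requires_backends


def clean_caption(text):
    # raises with BS4_IMPORT_ERROR / FTFY_IMPORT_ERROR if either package is missing
    requires_backends(clean_caption, ["bs4", "ftfy"])
    return text.strip()
```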
@@ -1,6 +1,7 @@
import PIL.Image
import PIL.ImageOps
from packaging import version
from PIL import Image


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -19,3 +20,26 @@ else:
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }


def pt_to_pil(images):
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = numpy_to_pil(images)
    return images


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
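`pt_to_pil` is the helper the IF doc snippets use to hand stage I/II tensor outputs to the next stage. A quick self-contained example:

```py
import torch

from diffusers.utils import pt_to_pil

# stand-in for a pipeline output produced with output_type="pt": NCHW floats in [-1, 1]
images = torch.rand(2, 3, 64, 64) * 2 - 1

pil_images = pt_to_pil(images)
pil_images[0].save("stage_1_preview.png")
```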
272
tests/pipelines/deepfloyd_if/__init__.py
Normal file
@@ -0,0 +1,272 @@
import tempfile

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.pipelines.deepfloyd_if import IFWatermarker
from diffusers.utils.testing_utils import torch_device

from ..test_pipelines_common import to_np


# WARN: the hf-internal-testing/tiny-random-t5 text encoder has some non-determinism in the `save_load` tests.


class IFPipelineTesterMixin:
    def _get_dummy_components(self):
        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            sample_size=32,
            layers_per_block=1,
            block_out_channels=[32, 64],
            down_block_types=[
                "ResnetDownsampleBlock2D",
                "SimpleCrossAttnDownBlock2D",
            ],
            mid_block_type="UNetMidBlock2DSimpleCrossAttn",
            up_block_types=["SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"],
            in_channels=3,
            out_channels=6,
            cross_attention_dim=32,
            encoder_hid_dim=32,
            attention_head_dim=8,
            addition_embed_type="text",
            addition_embed_type_num_heads=2,
            cross_attention_norm="group_norm",
            resnet_time_scale_shift="scale_shift",
            act_fn="gelu",
        )
        unet.set_attn_processor(AttnAddedKVProcessor())  # For reproducibility tests

        torch.manual_seed(0)
        scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_schedule="squaredcos_cap_v2",
            beta_start=0.0001,
            beta_end=0.02,
            thresholding=True,
            dynamic_thresholding_ratio=0.95,
            sample_max_value=1.0,
            prediction_type="epsilon",
            variance_type="learned_range",
        )

        torch.manual_seed(0)
        watermarker = IFWatermarker()

        return {
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "unet": unet,
            "scheduler": scheduler,
            "watermarker": watermarker,
            "safety_checker": None,
            "feature_extractor": None,
        }

    def _get_superresolution_dummy_components(self):
        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            sample_size=32,
            layers_per_block=[1, 2],
            block_out_channels=[32, 64],
            down_block_types=[
                "ResnetDownsampleBlock2D",
                "SimpleCrossAttnDownBlock2D",
            ],
            mid_block_type="UNetMidBlock2DSimpleCrossAttn",
            up_block_types=["SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"],
            in_channels=6,
            out_channels=6,
            cross_attention_dim=32,
            encoder_hid_dim=32,
            attention_head_dim=8,
            addition_embed_type="text",
            addition_embed_type_num_heads=2,
            cross_attention_norm="group_norm",
            resnet_time_scale_shift="scale_shift",
            act_fn="gelu",
            class_embed_type="timestep",
            mid_block_scale_factor=1.414,
            time_embedding_act_fn="gelu",
            time_embedding_dim=32,
        )
        unet.set_attn_processor(AttnAddedKVProcessor())  # For reproducibility tests

        torch.manual_seed(0)
        scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_schedule="squaredcos_cap_v2",
            beta_start=0.0001,
            beta_end=0.02,
            thresholding=True,
            dynamic_thresholding_ratio=0.95,
            sample_max_value=1.0,
            prediction_type="epsilon",
            variance_type="learned_range",
        )

        torch.manual_seed(0)
        image_noising_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_schedule="squaredcos_cap_v2",
            beta_start=0.0001,
            beta_end=0.02,
        )

        torch.manual_seed(0)
        watermarker = IFWatermarker()

        return {
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "unet": unet,
            "scheduler": scheduler,
            "image_noising_scheduler": image_noising_scheduler,
            "watermarker": watermarker,
            "safety_checker": None,
            "feature_extractor": None,
        }

    # This test is modified from the base class because IF pipelines set the text encoder
    # as optional with the intention that the user is allowed to encode the prompt once
    # and then pass the embeddings directly to the pipeline. The base class test uses
    # the unmodified arguments from `self.get_dummy_inputs`, which would pass the unencoded
    # prompt to the pipeline when the text encoder is set to None, throwing an error.
    # So we make the test reflect the intended usage of setting the text encoder to None.
    def _test_save_load_optional_components(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = inputs["prompt"]
        generator = inputs["generator"]
        num_inference_steps = inputs["num_inference_steps"]
        output_type = inputs["output_type"]

        if "image" in inputs:
            image = inputs["image"]
        else:
            image = None

        if "mask_image" in inputs:
            mask_image = inputs["mask_image"]
        else:
            mask_image = None

        if "original_image" in inputs:
            original_image = inputs["original_image"]
        else:
            original_image = None

        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)

        # inputs with prompt converted to embeddings
        inputs = {
            "prompt_embeds": prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
            "generator": generator,
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
        }

        if image is not None:
            inputs["image"] = image

        if mask_image is not None:
            inputs["mask_image"] = mask_image

        if original_image is not None:
            inputs["original_image"] = original_image

        # set all optional components to None
        for optional_component in pipe._optional_components:
            setattr(pipe, optional_component, None)

        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
            pipe_loaded.to(torch_device)
            pipe_loaded.set_progress_bar_config(disable=None)

        pipe_loaded.unet.set_attn_processor(AttnAddedKVProcessor())  # For reproducibility tests

        for optional_component in pipe._optional_components:
            self.assertTrue(
                getattr(pipe_loaded, optional_component) is None,
                f"`{optional_component}` did not stay set to None after loading.",
            )

        inputs = self.get_dummy_inputs(torch_device)

        generator = inputs["generator"]
        num_inference_steps = inputs["num_inference_steps"]
        output_type = inputs["output_type"]

        # inputs with prompt converted to embeddings
        inputs = {
            "prompt_embeds": prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
            "generator": generator,
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
        }

        if image is not None:
            inputs["image"] = image

        if mask_image is not None:
            inputs["mask_image"] = mask_image

        if original_image is not None:
            inputs["original_image"] = original_image

        output_loaded = pipe_loaded(**inputs)[0]

        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, 1e-4)

    # Modified from `PipelineTesterMixin` to set the attn processor as it's not serialized.
    # This should be handled in the base test and then this method can be removed.
    def _test_save_load_local(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
            pipe_loaded.to(torch_device)
            pipe_loaded.set_progress_bar_config(disable=None)

        pipe_loaded.unet.set_attn_processor(AttnAddedKVProcessor())  # For reproducibility tests

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]

        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, 1e-4)
340
tests/pipelines/deepfloyd_if/test_if.py
Normal file
@@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import torch

from diffusers import (
    IFImg2ImgPipeline,
    IFImg2ImgSuperResolutionPipeline,
    IFInpaintingPipeline,
    IFInpaintingSuperResolutionPipeline,
    IFPipeline,
    IFSuperResolutionPipeline,
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from . import IFPipelineTesterMixin


@skip_mps
class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"width", "height", "latents"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )


@slow
@require_torch_gpu
class IFPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_all(self):
        # if

        pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)

        pipe_2 = IFSuperResolutionPipeline.from_pretrained(
            "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16, text_encoder=None, tokenizer=None
        )

        # pre compute text embeddings and remove T5 to save memory

        pipe_1.text_encoder.to("cuda")

        prompt_embeds, negative_prompt_embeds = pipe_1.encode_prompt("anime turtle", device="cuda")

        del pipe_1.tokenizer
        del pipe_1.text_encoder
        gc.collect()

        pipe_1.tokenizer = None
        pipe_1.text_encoder = None

        pipe_1.enable_model_cpu_offload()
        pipe_2.enable_model_cpu_offload()

        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())

        self._test_if(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)

        pipe_1.remove_all_hooks()
        pipe_2.remove_all_hooks()

        # img2img

        pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
        pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)

        pipe_1.enable_model_cpu_offload()
        pipe_2.enable_model_cpu_offload()

        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())

        self._test_if_img2img(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)

        pipe_1.remove_all_hooks()
        pipe_2.remove_all_hooks()

        # inpainting

        pipe_1 = IFInpaintingPipeline(**pipe_1.components)
        pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)

        pipe_1.enable_model_cpu_offload()
        pipe_2.enable_model_cpu_offload()

        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())

        self._test_if_inpainting(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)

    def _test_if(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
        # pipeline 1

        _start_torch_memory_measurement()

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipe_1(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            num_inference_steps=2,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (64, 64, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 13 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if.npy"
        )
        assert_mean_pixel_difference(image, expected_image)

        # pipeline 2

        _start_torch_memory_measurement()

        generator = torch.Generator(device="cpu").manual_seed(0)

        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)

        output = pipe_2(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 4 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy"
        )
        assert_mean_pixel_difference(image, expected_image)

    def _test_if_img2img(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
        # pipeline 1

        _start_torch_memory_measurement()

        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)

        generator = torch.Generator(device="cpu").manual_seed(0)

        output = pipe_1(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=image,
            num_inference_steps=2,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (64, 64, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 10 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy"
        )
        assert_mean_pixel_difference(image, expected_image)

        # pipeline 2

        _start_torch_memory_measurement()

        generator = torch.Generator(device="cpu").manual_seed(0)

        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)

        output = pipe_2(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=image,
            original_image=original_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 4 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy"
        )
        assert_mean_pixel_difference(image, expected_image)

    def _test_if_inpainting(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
        # pipeline 1

        _start_torch_memory_measurement()

        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
        mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipe_1(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=image,
            mask_image=mask_image,
            num_inference_steps=2,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (64, 64, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 10 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy"
        )
        assert_mean_pixel_difference(image, expected_image)

        # pipeline 2

        _start_torch_memory_measurement()

        generator = torch.Generator(device="cpu").manual_seed(0)

        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
        mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device)

        output = pipe_2(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image=image,
            mask_image=mask_image,
            original_image=original_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)

        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes < 4 * 10**9

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy"
        )
        assert_mean_pixel_difference(image, expected_image)


def _start_torch_memory_measurement():
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
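The slow test above leans on a memory-saving pattern worth calling out: the IF text-to-image, img2img, and inpainting pipelines share the same underlying modules, so an already-loaded pipeline can be rewrapped without reloading any weights. A minimal sketch:

```py
import torch

from diffusers import IFImg2ImgPipeline, IFPipeline

pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)

# reuse the already-loaded components; no additional memory is consumed
img2img = IFImg2ImgPipeline(**pipe.components)
```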
84
tests/pipelines/deepfloyd_if/test_if_img2img.py
Normal file
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import torch

from diffusers import IFImg2ImgPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin


@skip_mps
class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_float16_inference(self):
        self._test_float16_inference(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )
79
tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
Normal file
@@ -0,0 +1,79 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import torch

from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin


@skip_mps
class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFImg2ImgSuperResolutionPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_superresolution_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        original_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = floats_tensor((1, 3, 16, 16), rng=random.Random(seed)).to(device)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "original_image": original_image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )
82
tests/pipelines/deepfloyd_if/test_if_inpainting.py
Normal file
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import torch

from diffusers import IFInpaintingPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin


@skip_mps
class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFInpaintingPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        mask_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )
84
tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
Normal file
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import torch

from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin


@skip_mps
class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFInpaintingSuperResolutionPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_superresolution_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 16, 16), rng=random.Random(seed)).to(device)
        original_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        mask_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "original_image": original_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )
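The only structural difference from the plain inpainting tests is the extra `original_image` input, which is also why `batch_params` gains "original_image" above. Judging from the dummy shapes (this reading is mine, not stated in the diff), `image` carries the low-resolution stage-1 result while `original_image` and `mask_image` are at the full target resolution. A sketch of the call, where `pipe` is assumed to be an IFInpaintingSuperResolutionPipeline built from the dummy components:

import random

import torch

from diffusers.utils import floats_tensor

seed = 0
generator = torch.Generator(device="cpu").manual_seed(seed)
image = floats_tensor((1, 3, 16, 16), rng=random.Random(seed))           # low-res previous-stage result
original_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))  # full-resolution source image
mask_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))      # mask at the full resolution

# `pipe` is assumed to exist; only the argument shapes are the point here.
out = pipe(
    prompt="A painting of a squirrel eating a burger",
    image=image,
    original_image=original_image,
    mask_image=mask_image,
    generator=generator,
    num_inference_steps=2,
    output_type="numpy",
)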
77
tests/pipelines/deepfloyd_if/test_if_superresolution.py
Normal file
@@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import torch

from diffusers import IFSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin


@skip_mps
class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
    pipeline_class = IFSuperResolutionPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

    test_xformers_attention = False

    def get_dummy_components(self):
        return self._get_superresolution_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }

        return inputs

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
        self._test_save_load_float16(expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(expected_max_diff=1e-2)

    def test_save_load_local(self):
        self._test_save_load_local()

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(
            expected_max_diff=1e-2,
        )
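All three new test files repeat the same device-dependent seeding branch in get_dummy_inputs(). A small self-contained sketch of the rationale (my gloss, not from the commit):

import torch


def make_generator(device, seed):
    # Per-device torch.Generator seeding is not supported on the "mps" backend
    # the way it is on "cpu"/"cuda", so the tests fall back to seeding the
    # global CPU RNG there; torch.manual_seed() returns that default Generator.
    if str(device).startswith("mps"):
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)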
@@ -575,6 +575,19 @@ class DownloadTests(unittest.TestCase):
         out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
         assert out.shape == (1, 128, 128, 3)

+    def test_download_ignore_files(self):
+        # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # pipeline has Flax weights
+            tmpdirname = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files")
+            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # None of the downloaded files should be a pytorch file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
+            assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
+            assert len(files) == 14
+

 class CustomPipelineTests(unittest.TestCase):
     def test_load_custom_pipeline(self):
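For context (my summary, hedged): the new test exercises an ignore-files entry in the linked test repo's model_index.json, which DiffusionPipeline.download() is expected to honor so the listed files are never fetched. A sketch of what such an entry presumably looks like, written as the Python dict the JSON would load into; the key name and values are an assumption for illustration, the exact contents live in the hf-internal-testing repo referenced above:

# Hypothetical shape of the relevant model_index.json entry:
model_index = {
    "_class_name": "StableDiffusionPipeline",
    "_diffusers_version": "0.15.0",
    "_ignore_files": ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"],
    # ... component entries (unet, vae, text_encoder, ...) ...
}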
@@ -339,6 +339,9 @@ class PipelineTesterMixin:

     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_float16_inference(self):
+        self._test_float16_inference()
+
+    def _test_float16_inference(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -352,10 +355,13 @@ class PipelineTesterMixin:
         output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]

         max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
+        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")

     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
+        self._test_save_load_float16()
+
+    def _test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
             if hasattr(module, "half"):
@@ -384,7 +390,9 @@ class PipelineTesterMixin:
         output_loaded = pipe_loaded(**inputs)[0]

         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
+        self.assertLess(
+            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+        )

     def test_save_load_optional_components(self):
         if not hasattr(self.pipeline_class, "_optional_components"):
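The net effect of these three hunks: the fp16 checks now delegate to `_test_*` helpers that accept an `expected_max_diff` tolerance, which is exactly what lets the IF test classes above loosen it to 1e-1 for the non-deterministic tiny T5 encoder. A minimal sketch of the resulting override pattern (class name hypothetical; PipelineTesterMixin comes from tests/pipelines/test_pipelines_common.py):

import unittest

from diffusers.utils.testing_utils import torch_device


class MyNoisyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    # ... pipeline_class, params, dummy components as usual ...

    @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
    def test_save_load_float16(self):
        # this pipeline is not save/load deterministic in fp16, so widen the
        # tolerance instead of re-implementing the whole test body
        self._test_save_load_float16(expected_max_diff=1e-1)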