Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Refactor controlnet and add img2img and inpaint (#3386)
* refactor controlnet and add img2img and inpaint
* First draft to get pipelines to work
* make style
* Fix more
* Fix more
* More tests
* Fix more
* Make inpainting work
* make style and more tests
* Apply suggestions from code review
* up
* make style
* Fix imports
* Fix more
* Fix more
* Improve examples
* add test
* Make sure import is correctly deprecated
* Make sure everything works in compile mode
* make sure authorship is correctly attributed
Commit 886575ee43 (parent 9d44e2fb66), committed via GitHub.
@@ -148,6 +148,8 @@
    title: Audio Diffusion
  - local: api/pipelines/audioldm
    title: AudioLDM
  - local: api/pipelines/controlnet
    title: ControlNet
  - local: api/pipelines/cycle_diffusion
    title: Cycle Diffusion
  - local: api/pipelines/dance_diffusion
@@ -203,8 +205,6 @@
    title: Self-Attention Guidance
  - local: api/pipelines/stable_diffusion/panorama
    title: MultiDiffusion Panorama
  - local: api/pipelines/stable_diffusion/controlnet
    title: Text-to-Image Generation with ControlNet Conditioning
  - local: api/pipelines/stable_diffusion/model_editing
    title: Text-to-Image Model Editing
  - local: api/pipelines/stable_diffusion/diffedit
@@ -22,7 +22,7 @@ The abstract of the paper is the following:

*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal device. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*

This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️.
This model was contributed by the community contributor [takuma104](https://huggingface.co/takuma104) ❤️.

Resources:

@@ -33,7 +33,9 @@ Resources:

| Pipeline | Tasks | Demo
|---|---|:---:|
| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [StableDiffusionControlNetImg2ImgPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py) | *Image-to-Image Generation with ControlNet Conditioning* |
| [StableDiffusionControlNetInpaintPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py) | *Inpainting Generation with ControlNet Conditioning* |

## Usage example
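The body of this usage example is collapsed in the diff; as a stand-in, here is a minimal sketch of the text-to-image ControlNet flow with canny conditioning. The input URL and prompt are placeholders:

```py
import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Placeholder input image; any RGB photo works.
image = load_image(
    "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
)
canny = cv2.Canny(np.array(image), 100, 200)              # edge map used as the ControlNet condition
canny_image = Image.fromarray(np.stack([canny] * 3, -1))  # back to a 3-channel PIL image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

out = pipe("bird, best quality", image=canny_image, num_inference_steps=20).images[0]
```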
@@ -301,21 +303,22 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
### ControlNet v1.1
| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|---|---|---|---|
|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)<br/> *Trained with pixel to pixel instruction* | No condition .|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)<br/> Trained with image inpainting | No condition.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"/></a>|
|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)<br/> Trained with multi-level line segment detection | An image with annotated line segments.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)<br/> Trained with depth estimation | An image with depth information, usually represented as a grayscale image.|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)<br/> Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)<br/> Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)<br/> Trained with line art generation | An image with line art, usually black lines on a white background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)<br/> Trained with anime line art generation | An image with anime-style line art.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)<br/> Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)<br/> Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)<br/> Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)<br/> Trained with image shuffling | An image with shuffled patches or regions.|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"/></a>|
| Model Name | Control Image Overview| Condition Image | Control Image Example | Generated Image Example |
|---|---|---|---|---|
|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)<br/> | *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)<br/> | *Trained with pixel to pixel instruction* | No condition .|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)<br/> | Trained with image inpainting | No condition.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"/></a>|
|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)<br/> | Trained with multi-level line segment detection | An image with annotated line segments.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)<br/> | Trained with depth estimation | An image with depth information, usually represented as a grayscale image.|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)<br/> | Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)<br/> | Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)<br/> | Trained with line art generation | An image with line art, usually black lines on a white background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)<br/> | Trained with anime line art generation | An image with anime-style line art.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose)<br/> | Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)<br/> | Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)<br/> | Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)<br/> | Trained with image shuffling | An image with shuffled patches or regions.|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"/></a>|
|[lllyasviel/control_v11f1e_sd15_tile](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile)<br/> | Trained with image tiling | A blurry image or part of an image .|<a href="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/images/original.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/images/original.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/images/output.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/main/images/output.png"/></a>|
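All of the v1.1 checkpoints in the table load through the same `ControlNetModel.from_pretrained` call used for the v1.0 weights; only the repo id changes. A small sketch:

```py
import torch
from diffusers import ControlNetModel

# Any repo id from the table above works here, e.g. the canny, depth, or openpose variants.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16)
```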
## StableDiffusionControlNetPipeline
[[autodoc]] StableDiffusionControlNetPipeline
@@ -329,6 +332,30 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
- disable_xformers_memory_efficient_attention
- load_textual_inversion

## StableDiffusionControlNetImg2ImgPipeline
[[autodoc]] StableDiffusionControlNetImg2ImgPipeline
- all
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
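Since the full usage section is not shown in this diff, here is a minimal, hedged sketch of calling the new img2img pipeline. The `image` / `control_image` split follows `pipeline_controlnet_img2img.py` as added in this PR; the input URL and prompt are placeholders:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
)
control_image = init_image  # normally a canny edge map of the input; reused as-is only to keep the sketch short

image = pipe(
    "a futuristic city, best quality",
    image=init_image,             # the image to be transformed
    control_image=control_image,  # the ControlNet conditioning image
    strength=0.8,
    num_inference_steps=30,
).images[0]
```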

## StableDiffusionControlNetInpaintPipeline
[[autodoc]] StableDiffusionControlNetInpaintPipeline
- all
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- load_textual_inversion
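And the inpainting counterpart, again as a hedged sketch: the call takes the image to repair, a mask, and a separate ControlNet conditioning image (`image`, `mask_image`, `control_image`). The URLs and prompt are placeholders; `lllyasviel/control_v11p_sd15_inpaint` comes from the checkpoint table above:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://path/to/your/image.png")  # placeholder: image to repair
mask_image = load_image("https://path/to/your/mask.png")   # placeholder: white pixels are repainted
control_image = init_image                                 # conditioning image; often derived from image + mask

result = pipe(
    "a red brick wall, best quality",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    num_inference_steps=30,
).images[0]
```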

## FlaxStableDiffusionControlNetPipeline
[[autodoc]] FlaxStableDiffusionControlNetPipeline
- all
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
|---|---|:---:|:---:|
| [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
| [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
@@ -53,7 +53,7 @@ The library has three main components:
|---|---|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
@@ -132,6 +132,8 @@ else:
        PaintByExamplePipeline,
        SemanticStableDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
@@ -17,3 +17,13 @@
# It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works

from .pipelines import DiffusionPipeline, ImagePipelineOutput  # noqa: F401
from .utils import deprecate


deprecate(
    "pipelines_utils",
    "0.22.0",
    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
    standard_warn=False,
    stacklevel=3,
)
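In practice this shim keeps the old import path working (with the warning above) until 0.22.0; a quick sketch of the deprecated versus the recommended path named in the message:

```py
# Old location: still importable, but emits the deprecation warning defined above.
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput  # noqa: F401

# New location recommended by the deprecation message.
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput  # noqa: F811
```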
@@ -44,6 +44,11 @@ except OptionalDependencyNotAvailable:
else:
    from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
    from .audioldm import AudioLDMPipeline
    from .controlnet import (
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
        StableDiffusionControlNetPipeline,
    )
    from .deepfloyd_if import (
        IFImg2ImgPipeline,
        IFImg2ImgSuperResolutionPipeline,
@@ -58,7 +63,6 @@ else:
    from .stable_diffusion import (
        CycleDiffusionPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
        StableDiffusionImageVariationPipeline,
@@ -133,8 +137,8 @@ try:
except OptionalDependencyNotAvailable:
    from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
else:
    from .controlnet import FlaxStableDiffusionControlNetPipeline
    from .stable_diffusion import (
        FlaxStableDiffusionControlNetPipeline,
        FlaxStableDiffusionImg2ImgPipeline,
        FlaxStableDiffusionInpaintPipeline,
        FlaxStableDiffusionPipeline,
src/diffusers/pipelines/controlnet/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .multicontrolnet import MultiControlNetModel
    from .pipeline_controlnet import StableDiffusionControlNetPipeline
    from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
    from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline


if is_transformers_available() and is_flax_available():
    from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
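With this `__init__.py` in place (and the matching top-level re-exports shown earlier in this diff), the new pipelines can be imported either from `diffusers` directly or from the subpackage. A small sanity-check sketch, assuming a torch + transformers install:

```py
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionControlNetPipeline,
)
from diffusers.pipelines.controlnet import MultiControlNetModel

# The classes now live under diffusers.pipelines.controlnet.*
print(StableDiffusionControlNetImg2ImgPipeline.__module__)
print(MultiControlNetModel.__module__)
```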
src/diffusers/pipelines/controlnet/multicontrolnet.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # merge samples
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
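For context, a minimal sketch of how this wrapper is exercised in practice: `StableDiffusionControlNetPipeline` accepts a list of ControlNets and wraps them in `MultiControlNetModel`, expecting one condition image and one conditioning scale per net. The checkpoint ids are real, but the prompt and the reuse of a single test image for both conditions are placeholder choices:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Two ControlNets conditioned on different signals (canny edges + human pose).
controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
]

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")

# One conditioning image per ControlNet; the same placeholder image is reused here for brevity.
cond = load_image(
    "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
)

image = pipe(
    "best quality, extremely detailed",
    image=[cond, cond],                        # a list is routed through MultiControlNetModel.forward
    controlnet_conditioning_scale=[1.0, 0.8],  # per-net scales, accumulated as in the loop above
    num_inference_steps=20,
).images[0]
```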
src/diffusers/pipelines/controlnet/pipeline_controlnet.py (new file, 1035 lines; diff too large to display)
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py (new file, 1113 lines; diff too large to display)
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py (new file, 1228 lines; diff too large to display)
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py (new file, 537 lines)
@@ -0,0 +1,537 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
)
|
||||
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
|
||||
from ..pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
from ..stable_diffusion import FlaxStableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import jax
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from flax.jax_utils import replicate
|
||||
>>> from flax.training.common_utils import shard
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
|
||||
|
||||
|
||||
>>> def image_grid(imgs, rows, cols):
|
||||
... w, h = imgs[0].size
|
||||
... grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
... for i, img in enumerate(imgs):
|
||||
... grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
... return grid
|
||||
|
||||
|
||||
>>> def create_key(seed=0):
|
||||
... return jax.random.PRNGKey(seed)
|
||||
|
||||
|
||||
>>> rng = create_key(0)
|
||||
|
||||
>>> # get canny image
|
||||
>>> canny_image = load_image(
|
||||
... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
|
||||
... )
|
||||
|
||||
>>> prompts = "best quality, extremely detailed"
|
||||
>>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
|
||||
... )
|
||||
>>> params["controlnet"] = controlnet_params
|
||||
|
||||
>>> num_samples = jax.device_count()
|
||||
>>> rng = jax.random.split(rng, jax.device_count())
|
||||
|
||||
>>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
||||
>>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
||||
>>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
||||
|
||||
>>> p_params = replicate(params)
|
||||
>>> prompt_ids = shard(prompt_ids)
|
||||
>>> negative_prompt_ids = shard(negative_prompt_ids)
|
||||
>>> processed_image = shard(processed_image)
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt_ids=prompt_ids,
|
||||
... image=processed_image,
|
||||
... params=p_params,
|
||||
... prng_seed=rng,
|
||||
... num_inference_steps=50,
|
||||
... neg_prompt_ids=negative_prompt_ids,
|
||||
... jit=True,
|
||||
... ).images
|
||||
|
||||
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
|
||||
>>> output_images = image_grid(output_images, num_samples // 4, 4)
|
||||
>>> output_images.save("generated_image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.
|
||||
|
||||
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`FlaxAutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`FlaxCLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
controlnet ([`FlaxControlNetModel`]):
|
||||
Provides additional conditioning to the unet during the denoising process.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
|
||||
[`FlaxDPMSolverMultistepScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: FlaxAutoencoderKL,
|
||||
text_encoder: FlaxCLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: FlaxUNet2DConditionModel,
|
||||
controlnet: FlaxControlNetModel,
|
||||
scheduler: Union[
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
return text_input.input_ids
|
||||
|
||||
def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
|
||||
if not isinstance(image, (Image.Image, list)):
|
||||
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
|
||||
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
|
||||
|
||||
return processed_images
|
||||
|
||||
def _get_has_nsfw_concepts(self, features, params):
|
||||
has_nsfw_concepts = self.safety_checker(features, params)
|
||||
return has_nsfw_concepts
|
||||
|
||||
def _run_safety_checker(self, images, safety_model_params, jit=False):
|
||||
# safety_model_params should already be replicated when jit is True
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
|
||||
|
||||
if jit:
|
||||
features = shard(features)
|
||||
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
|
||||
has_nsfw_concepts = unshard(has_nsfw_concepts)
|
||||
safety_model_params = unreplicate(safety_model_params)
|
||||
else:
|
||||
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
|
||||
|
||||
images_was_copied = False
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
if not images_was_copied:
|
||||
images_was_copied = True
|
||||
images = images.copy()
|
||||
|
||||
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
warnings.warn(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned"
|
||||
" instead. Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
height, width = image.shape[-2:]
|
||||
if height % 64 != 0 or width % 64 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
|
||||
|
||||
# get prompt text embeddings
|
||||
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
|
||||
|
||||
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
batch_size = prompt_ids.shape[0]
|
||||
|
||||
max_length = prompt_ids.shape[-1]
|
||||
|
||||
if neg_prompt_ids is None:
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
).input_ids
|
||||
else:
|
||||
uncond_input = neg_prompt_ids
|
||||
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
image = jnp.concatenate([image] * 2)
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.config.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
def loop_body(step, args):
|
||||
latents, scheduler_state = args
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
latents_input = jnp.concatenate([latents] * 2)
|
||||
|
||||
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
|
||||
timestep = jnp.broadcast_to(t, latents_input.shape[0])
|
||||
|
||||
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
|
||||
{"params": params["controlnet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet.apply(
|
||||
{"params": params["unet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
|
||||
return latents, scheduler_state
|
||||
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
|
||||
)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
else:
|
||||
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||
|
||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt_ids (`jnp.array`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`jnp.array`):
|
||||
Array representing the ControlNet input condition. ControlNet uses this input condition to generate
guidance for the UNet.
|
||||
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
if isinstance(guidance_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
guidance_scale = guidance_scale[:, None]
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_params = params["safety_checker"]
|
||||
images_uint8_casted = (images * 255).round().astype("uint8")
|
||||
num_devices, batch_size = images.shape[:2]
|
||||
|
||||
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
|
||||
images = np.asarray(images)
|
||||
|
||||
# block images
|
||||
if any(has_nsfw_concept):
|
||||
for i, is_nsfw in enumerate(has_nsfw_concept):
|
||||
if is_nsfw:
|
||||
images[i] = np.asarray(images_uint8_casted[i])
|
||||
|
||||
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||
else:
|
||||
images = np.asarray(images)
|
||||
has_nsfw_concept = False
|
||||
|
||||
if not return_dict:
|
||||
return (images, has_nsfw_concept)
|
||||
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
|
||||
static_broadcasted_argnums=(0, 5),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0,))
|
||||
def _p_get_has_nsfw_concepts(pipe, features, params):
|
||||
return pipe._get_has_nsfw_concepts(features, params)
|
||||
|
||||
|
||||
def unshard(x: jnp.ndarray):
|
||||
# einops.rearrange(x, 'd b ... -> (d b) ...')
|
||||
num_devices, batch_size = x.shape[:2]
|
||||
rest = x.shape[2:]
|
||||
return x.reshape(num_devices * batch_size, *rest)
|
||||
|
||||
|
||||
def preprocess(image, dtype):
|
||||
image = image.convert("RGB")
|
||||
w, h = image.size
|
||||
w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = jnp.array(image).astype(dtype) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
return image
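One practical consequence of `preprocess` (together with the 64-divisibility check in `_generate`): control images are silently cropped down to multiples of 64 before being scaled to [0, 1] NCHW arrays. A tiny illustration with an assumed 600×512 input:

```py
from PIL import Image

# Hypothetical 600x512 control image; preprocess() trims each side down to a multiple of 64.
img = Image.new("RGB", (600, 512))
w, h = (x - x % 64 for x in img.size)
print((w, h))  # (576, 512) -> the processed array has shape (1, 3, 512, 576)
```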
|
||||
@@ -8,10 +8,10 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import SemanticStableDiffusionPipelineOutput
@@ -45,7 +45,6 @@ else:
|
||||
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
|
||||
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
|
||||
@@ -130,7 +129,6 @@ if is_transformers_available() and is_flax_available():
|
||||
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
|
||||
from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
|
||||
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
@@ -12,526 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
# NOTE: This file is deprecated and will be removed in a future version.
|
||||
# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
from ...utils import deprecate
|
||||
from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
|
||||
deprecate(
|
||||
"stable diffusion controlnet",
|
||||
"0.22.0",
|
||||
"Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.",
|
||||
standard_warn=False,
|
||||
stacklevel=3,
|
||||
)
|
||||
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
|
||||
from ..pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
from . import FlaxStableDiffusionPipelineOutput
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import jax
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from flax.jax_utils import replicate
|
||||
>>> from flax.training.common_utils import shard
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
|
||||
|
||||
|
||||
>>> def image_grid(imgs, rows, cols):
|
||||
... w, h = imgs[0].size
|
||||
... grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
... for i, img in enumerate(imgs):
|
||||
... grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
... return grid
|
||||
|
||||
|
||||
>>> def create_key(seed=0):
|
||||
... return jax.random.PRNGKey(seed)
|
||||
|
||||
|
||||
>>> rng = create_key(0)
|
||||
|
||||
>>> # get canny image
|
||||
>>> canny_image = load_image(
|
||||
... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
|
||||
... )
|
||||
|
||||
>>> prompts = "best quality, extremely detailed"
|
||||
>>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
||||
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
|
||||
... )
|
||||
>>> params["controlnet"] = controlnet_params
|
||||
|
||||
>>> num_samples = jax.device_count()
|
||||
>>> rng = jax.random.split(rng, jax.device_count())
|
||||
|
||||
>>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
||||
>>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
|
||||
>>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
|
||||
|
||||
>>> p_params = replicate(params)
|
||||
>>> prompt_ids = shard(prompt_ids)
|
||||
>>> negative_prompt_ids = shard(negative_prompt_ids)
|
||||
>>> processed_image = shard(processed_image)
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt_ids=prompt_ids,
|
||||
... image=processed_image,
|
||||
... params=p_params,
|
||||
... prng_seed=rng,
|
||||
... num_inference_steps=50,
|
||||
... neg_prompt_ids=negative_prompt_ids,
|
||||
... jit=True,
|
||||
... ).images
|
||||
|
||||
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
|
||||
>>> output_images = image_grid(output_images, num_samples // 4, 4)
|
||||
>>> output_images.save("generated_image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.
|
||||
|
||||
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`FlaxAutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`FlaxCLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
controlnet ([`FlaxControlNetModel`]:
|
||||
Provides additional conditioning to the unet during the denoising process.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
|
||||
[`FlaxDPMSolverMultistepScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: FlaxAutoencoderKL,
|
||||
text_encoder: FlaxCLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: FlaxUNet2DConditionModel,
|
||||
controlnet: FlaxControlNetModel,
|
||||
scheduler: Union[
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
return text_input.input_ids
|
||||
|
||||
def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
|
||||
if not isinstance(image, (Image.Image, list)):
|
||||
raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
image = [image]
|
||||
|
||||
processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
|
||||
|
||||
return processed_images
|
||||
|
||||
def _get_has_nsfw_concepts(self, features, params):
|
||||
has_nsfw_concepts = self.safety_checker(features, params)
|
||||
return has_nsfw_concepts
|
||||
|
||||
def _run_safety_checker(self, images, safety_model_params, jit=False):
|
||||
# safety_model_params should already be replicated when jit is True
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
|
||||
|
||||
if jit:
|
||||
features = shard(features)
|
||||
has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
|
||||
has_nsfw_concepts = unshard(has_nsfw_concepts)
|
||||
safety_model_params = unreplicate(safety_model_params)
|
||||
else:
|
||||
has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
|
||||
|
||||
images_was_copied = False
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
if not images_was_copied:
|
||||
images_was_copied = True
|
||||
images = images.copy()
|
||||
|
||||
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
warnings.warn(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned"
|
||||
" instead. Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
neg_prompt_ids: Optional[jnp.array] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
height, width = image.shape[-2:]
|
||||
if height % 64 != 0 or width % 64 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
|
||||
|
||||
# get prompt text embeddings
|
||||
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
|
||||
|
||||
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
|
||||
batch_size = prompt_ids.shape[0]
|
||||
|
||||
max_length = prompt_ids.shape[-1]
|
||||
|
||||
if neg_prompt_ids is None:
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
).input_ids
|
||||
else:
|
||||
uncond_input = neg_prompt_ids
|
||||
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
image = jnp.concatenate([image] * 2)
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.config.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
def loop_body(step, args):
|
||||
latents, scheduler_state = args
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
latents_input = jnp.concatenate([latents] * 2)
|
||||
|
||||
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
|
||||
timestep = jnp.broadcast_to(t, latents_input.shape[0])
|
||||
|
||||
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
|
||||
{"params": params["controlnet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet.apply(
|
||||
{"params": params["unet"]},
|
||||
jnp.array(latents_input),
|
||||
jnp.array(timestep, dtype=jnp.int32),
|
||||
encoder_hidden_states=context,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
|
||||
return latents, scheduler_state
|
||||
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
|
||||
)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
else:
|
||||
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||
|
||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt_ids (`jnp.array`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`jnp.array`):
|
||||
Array representing the ControlNet input condition. ControlNet use this input condition to generate
|
||||
guidance to Unet.
|
||||
params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
if isinstance(guidance_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
guidance_scale = guidance_scale[:, None]
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_params = params["safety_checker"]
|
||||
images_uint8_casted = (images * 255).round().astype("uint8")
|
||||
num_devices, batch_size = images.shape[:2]
|
||||
|
||||
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
|
||||
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
|
||||
images = np.asarray(images)
|
||||
|
||||
# block images
|
||||
if any(has_nsfw_concept):
|
||||
for i, is_nsfw in enumerate(has_nsfw_concept):
|
||||
if is_nsfw:
|
||||
images[i] = np.asarray(images_uint8_casted[i])
|
||||
|
||||
images = images.reshape(num_devices, batch_size, height, width, 3)
|
||||
else:
|
||||
images = np.asarray(images)
|
||||
has_nsfw_concept = False
|
||||
|
||||
if not return_dict:
|
||||
return (images, has_nsfw_concept)
|
||||
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
|
||||
static_broadcasted_argnums=(0, 5),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids,
|
||||
image,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
latents,
|
||||
neg_prompt_ids,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0,))
|
||||
def _p_get_has_nsfw_concepts(pipe, features, params):
|
||||
return pipe._get_has_nsfw_concepts(features, params)
|
||||
|
||||
|
||||
def unshard(x: jnp.ndarray):
|
||||
# einops.rearrange(x, 'd b ... -> (d b) ...')
|
||||
num_devices, batch_size = x.shape[:2]
|
||||
rest = x.shape[2:]
|
||||
return x.reshape(num_devices * batch_size, *rest)
|
||||
|
||||
|
||||
def preprocess(image, dtype):
|
||||
image = image.convert("RGB")
|
||||
w, h = image.size
|
||||
w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = jnp.array(image).astype(dtype) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
return image
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -212,6 +212,36 @@ class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
0
tests/pipelines/controlnet/__init__.py
Normal file
0
tests/pipelines/controlnet/__init__.py
Normal file
@@ -34,7 +34,10 @@ from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_de
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -42,7 +45,7 @@ torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -155,6 +158,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
@@ -307,7 +311,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
|
||||
class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
366
tests/pipelines/controlnet/test_controlnet_img2img.py
Normal file
366
tests/pipelines/controlnet/test_controlnet_img2img.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
|
||||
class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
control_image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
image = floats_tensor(control_image.shape, rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
"control_image": control_image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet1 = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet2 = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
controlnet = MultiControlNetModel([controlnet1, controlnet2])
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
|
||||
control_image = [
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
]
|
||||
|
||||
image = floats_tensor(control_image[0].shape, rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
"control_image": control_image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_pretrained_raise_not_implemented_exception(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
try:
|
||||
# save_pretrained is not implemented for Multi-ControlNet
|
||||
pipe.save_pretrained(tmpdir)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_float16(self):
|
||||
...
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_local(self):
|
||||
...
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_optional_components(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "evil space-punk bird"
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
image,
|
||||
control_image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=50,
|
||||
strength=0.6,
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/img2img.npy"
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
379
tests/pipelines/controlnet/test_controlnet_inpaint.py
Normal file
379
tests/pipelines/controlnet/test_controlnet_inpaint.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This model implementation is heavily based on:
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
|
||||
class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset([])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
control_image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
init_image = init_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
|
||||
image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"control_image": control_image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
|
||||
class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet1 = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet2 = ControlNetModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
controlnet = MultiControlNetModel([controlnet1, controlnet2])
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
|
||||
control_image = [
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
),
|
||||
]
|
||||
init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
init_image = init_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
|
||||
image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"control_image": control_image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_save_pretrained_raise_not_implemented_exception(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
try:
|
||||
# save_pretrained is not implemented for Multi-ControlNet
|
||||
pipe.save_pretrained(tmpdir)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_float16(self):
|
||||
...
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_local(self):
|
||||
...
|
||||
|
||||
# override PipelineTesterMixin
|
||||
@unittest.skip("save pretrained not implemented")
|
||||
def test_save_load_optional_components(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
).resize((512, 512))
|
||||
|
||||
prompt = "pitch black hole"
|
||||
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
control_image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/inpaint.npy"
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
@@ -30,7 +30,7 @@ if is_flax_available():
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase):
|
||||
class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
@@ -46,9 +46,8 @@ class StableDiffusionImageVariationPipelineFastTests(
|
||||
pipeline_class = StableDiffusionImageVariationPipeline
|
||||
params = IMAGE_VARIATION_PARAMS
|
||||
batch_params = IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = frozenset(
|
||||
[]
|
||||
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
|
||||
image_params = frozenset([])
|
||||
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -47,9 +47,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset(
|
||||
[]
|
||||
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
|
||||
image_params = frozenset([])
|
||||
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user