diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c8bb2b8c0b..8ef2c2693b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -194,6 +194,8 @@ title: Consistency Models - local: api/pipelines/controlnet title: ControlNet + - local: api/pipelines/controlnet_sdxl + title: ControlNet with Stable Diffusion XL - local: api/pipelines/cycle_diffusion title: Cycle Diffusion - local: api/pipelines/dance_diffusion diff --git a/docs/source/en/api/pipelines/controlnet_sdxl.md b/docs/source/en/api/pipelines/controlnet_sdxl.md new file mode 100644 index 0000000000..8f0d759218 --- /dev/null +++ b/docs/source/en/api/pipelines/controlnet_sdxl.md @@ -0,0 +1,35 @@ + + +# ControlNet with Stable Diffusion XL + +[Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. + +Using a pretrained model, we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details. + +The abstract from the paper is: + +*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* + +We provide support using ControlNets with [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl.md) (SDXL). + +There are not many ControlNet checkpoints that are compatible with SDXL at the moment. So, we trained one using Canny edge maps as the conditioning images. To know more, check out the [model card](https://huggingface.co/diffusers/controlnet-sdxl-1.0). We encourage you to train custom ControlNets; we provide a [training script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md) for this. + +You can find some results below: + + + + +## StableDiffusionXLControlNetPipeline +[[autodoc]] StableDiffusionXLControlNetPipeline + - all + - __call__ \ No newline at end of file diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index bb95682215..1a3887b987 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -58,8 +58,43 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py - >>> # To be updated when there's a useful ControlNet checkpoint - >>> # compatible with SDXL. + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-sdxl-1.0", torch_dtype=torch.float16) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] ``` """