From e5f232f76bc8a6f5167285f414f208517861083f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Nov 2023 20:36:33 +0530 Subject: [PATCH] [Docs] add: 8bit inference with pixart alpha (#5814) * add: 8bit inference with pixart alpha * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add: note on 4bit. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * address comment --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Patrick von Platen --- docs/source/en/api/pipelines/pixart.md | 106 +++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/docs/source/en/api/pipelines/pixart.md b/docs/source/en/api/pipelines/pixart.md index 6fa44cd508..7d8ff2b36b 100644 --- a/docs/source/en/api/pipelines/pixart.md +++ b/docs/source/en/api/pipelines/pixart.md @@ -35,6 +35,112 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) +## Inference with under 8GB GPU VRAM + +Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example. + +First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library: + +```bash +pip install -U bitsandbytes +``` + +Then load the text encoder in 8-bit: + +```python +from transformers import T5EncoderModel +from diffusers import PixArtAlphaPipeline +import torch + +text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + subfolder="text_encoder", + load_in_8bit=True, + device_map="auto", + +) +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=text_encoder, + transformer=None, + device_map="auto" +) +``` + +Now, use the `pipe` to encode a prompt: + +```python +with torch.no_grad(): + prompt = "cute cat" + prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt) +``` + +Since text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up som GPU VRAM: + +```python +import gc + +def flush(): + gc.collect() + torch.cuda.empty_cache() + +del text_encoder +del pipe +flush() +``` + +Then compute the latents with the prompt embeddings as inputs: + +```python +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=None, + torch_dtype=torch.float16, +).to("cuda") + +latents = pipe( + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_images_per_prompt=1, + output_type="latent", +).images + +del pipe.transformer +flush() +``` + + + +Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded. + + + +Once the latents are computed, pass it off to the VAE to decode into a real image: + +```python +with torch.no_grad(): + image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] +image = pipe.image_processor.postprocess(image, output_type="pil")[0] +image.save("cat.png") +``` + +By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM. + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png) + +If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e). + + + +Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit. + + + +While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB. + ## PixArtAlphaPipeline [[autodoc]] PixArtAlphaPipeline