mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Docs] Add TGATE in section optimization (#7639)
* Create tgate.md * Update _toctree.yml * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/tgate.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update tgate.md * Update tgate.md --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -172,6 +172,8 @@
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
title: General optimizations
|
||||
- sections:
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
|
||||
179
docs/source/en/optimization/tgate.md
Normal file
179
docs/source/en/optimization/tgate.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# T-GATE
|
||||
|
||||
[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latency Consistency Model](../api/pipelines/latent_consistency_models.md) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and it can speed up inference from 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache).
|
||||
|
||||
Before you begin, make sure you install T-GATE.
|
||||
|
||||
```bash
|
||||
pip install tgate
|
||||
pip install -U pytorch diffusers transformers accelerate DeepCache
|
||||
```
|
||||
|
||||
|
||||
To use T-GATE with a pipeline, you need to use its corresponding loader.
|
||||
|
||||
| Pipeline | T-GATE Loader |
|
||||
|---|---|
|
||||
| PixArt | TgatePixArtLoader |
|
||||
| Stable Diffusion XL | TgateSDXLLoader |
|
||||
| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
|
||||
| Stable Diffusion | TgateSDLoader |
|
||||
| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
|
||||
|
||||
Next, create a `TgateLoader` with a pipeline, the gate step (the time step to stop calculating the cross attention), and the number of inference steps. Then call the `tgate` method on the pipeline with a prompt, gate step, and the number of inference steps.
|
||||
|
||||
Let's see how to enable this for several different pipelines.
|
||||
|
||||
<hfoptions id="pipelines">
|
||||
<hfoption id="PixArt">
|
||||
|
||||
Accelerate `PixArtAlphaPipeline` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import PixArtAlphaPipeline
|
||||
from tgate import TgatePixArtLoader
|
||||
|
||||
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
||||
pipe = TgatePixArtLoader(
|
||||
pipe,
|
||||
gate_step=8,
|
||||
num_inference_steps=25,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"An alpaca made of colorful building blocks, cyberpunk.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Stable Diffusion XL">
|
||||
|
||||
Accelerate `StableDiffusionXLPipeline` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLLoader
|
||||
gate_step = 10
|
||||
inference_step = 25
|
||||
pipe = TgateSDXLLoader(
|
||||
pipe,
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="StableDiffusionXL with DeepCache">
|
||||
|
||||
Accelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLDeepCacheLoader
|
||||
gate_step = 10
|
||||
inference_step = 25
|
||||
pipe = TgateSDXLDeepCacheLoader(
|
||||
pipe,
|
||||
cache_interval=3,
|
||||
cache_branch_id=0,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Latent Consistency Model">
|
||||
|
||||
Accelerate `latent-consistency/lcm-sdxl` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import UNet2DConditionModel, LCMScheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"latent-consistency/lcm-sdxl",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
unet=unet,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLLoader
|
||||
gate_step = 1
|
||||
inference_step = 4
|
||||
pipe = TgateSDXLLoader(
|
||||
pipe,
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
lcm=True
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
T-GATE also supports [`StableDiffusionPipeline`] and [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS).
|
||||
|
||||
## Benchmarks
|
||||
| Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO |
|
||||
|-----------------------|----------|-----------|---------|---------------------------|
|
||||
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
|
||||
| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
|
||||
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
|
||||
| SD-2.1 w/ T-GATE | 22.208T | 815.433 M | 9.878s | 19.940 |
|
||||
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
|
||||
| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
|
||||
| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
|
||||
| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
|
||||
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
|
||||
| DeepCache w/ T-GATE | 43.868T | - | 14.666s | 23.999 |
|
||||
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
|
||||
| LCM w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 |
|
||||
| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
|
||||
| LCM w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 |
|
||||
|
||||
The latency is tested on an NVIDIA 1080TI, MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and the FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid).
|
||||
Reference in New Issue
Block a user