
Removing autocast for 35-25% speedup. (autocast considered harmful). (#511)

* Removing `autocast` for `35-25% speedup`.

* iQuality

* Adding a slow test.

* Fixing mps noise generation.

* Raising error on wrong device, instead of just casting on behalf of user.

* Quality.

* fix merge

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Author: Nicolas Patry
Date: 2022-10-05 15:33:13 +02:00
Committed by: GitHub
Parent: 6b09f370c4
Commit: 3dcc75cbd4
6 changed files with 60 additions and 48 deletions
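Before the per-file diffs, a minimal sketch of the usage pattern this commit moves the examples toward: half precision is requested explicitly when loading the weights (the `fp16` revision with `torch_dtype=torch.float16`, as in the new slow test at the bottom of the diff), and the pipeline is called directly instead of wrapping inference in `torch.autocast`. The prompt and output filename are illustrative only.

```python
# Sketch of the pattern recommended after this change: explicit fp16 weights,
# no autocast context around the sampling loop.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
# previously: with autocast("cuda"): image = pipe(prompt).images[0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```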


@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
**Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
```
Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
import requests
import torch
from PIL import Image
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO
-from torch import autocast
import torch
import requests
import PIL
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
pipe = pipe.to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
-from torch import autocast
from diffusers import DiffusionPipeline
device = "cuda"
@@ -270,8 +260,7 @@ ldm = ldm.to(device)
# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
-with autocast(device):
-    image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
+image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
# save image
image.save("squirrel.png")
@@ -279,7 +268,6 @@ image.save("squirrel.png")
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
-from torch import autocast
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256"
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
ddpm.to(device)
# run pipeline in inference (sample random noise and denoise)
-with autocast("cuda"):
-    image = ddpm().images[0]
+image = ddpm().images[0]
# save image
image.save("ddpm_generated_image.png")


@@ -266,7 +266,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         timesteps = timesteps.expand(sample.shape[0])
         t_emb = self.time_proj(timesteps)
-        emb = self.time_embedding(t_emb.to(self.dtype))
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+        emb = self.time_embedding(t_emb)
         # 2. pre-process
         sample = self.conv_in(sample)
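The new comment above is the crux of the model-side change: the sinusoidal timestep projection always yields float32, so once autocast is gone the tensor has to be cast to the model's dtype before it reaches fp16 weights. A rough, standalone illustration of the failure mode follows; it is not taken from the repository, and the 320/1280 sizes and CUDA device are assumptions.

```python
# Illustration only: an fp16 Linear rejects a float32 input, which is why
# t_emb is cast to self.dtype before the time-embedding layers run.
import torch

time_embedding = torch.nn.Linear(320, 1280).to(device="cuda", dtype=torch.float16)
t_emb = torch.randn(4, 320, device="cuda")  # float32, like the sinusoidal projection

try:
    time_embedding(t_emb)  # dtype-mismatch RuntimeError now that autocast is gone
except RuntimeError as err:
    print(err)

emb = time_embedding(t_emb.to(torch.float16))  # fine after the explicit cast
print(emb.dtype)  # torch.float16
```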


@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python
-from torch import autocast
import requests
from PIL import Image
from io import BytesIO
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
```python
from io import BytesIO
-from torch import autocast
import requests
import PIL
@@ -173,8 +168,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
).to(device)
prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```


@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
).to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
```python
# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
).to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```


@@ -260,19 +260,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(latents_device)
+            latents = latents.to(self.device)
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
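The same idea can be shown outside the pipeline in a few lines: on `mps`, the seeded `randn` runs on CPU and the result is moved to the device so seeded runs stay reproducible, while every other device samples directly on the target device. The helper name, the seed handling, and the example shape below are assumptions for illustration, not pipeline API.

```python
# Sketch (not the pipeline code): mirror the branch added above by sampling
# initial latents on CPU for mps and directly on the target device otherwise.
import torch

def make_latents(shape, device, dtype, seed):
    if device.type == "mps":
        # randn with a generator was not available on mps at the time,
        # so sample on CPU and move the result over afterwards.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

# (batch_size, unet.in_channels, height // 8, width // 8) for a 512x512 image
latents = make_latents((1, 4, 64, 64), torch.device("cpu"), torch.float32, seed=0)
print(latents.shape)
```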


@@ -1214,6 +1214,37 @@ class PipelineTesterMixin(unittest.TestCase):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_text2img_pipeline_fp16(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        ).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output_chunked = pipe(
+            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+        )
+        image_chunked = output_chunked.images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image = output.images
+
+        # Make sure results are close enough
+        diff = np.abs(image_chunked.flatten() - image.flatten())
+        # They ARE different since ops are not run always at the same precision
+        # however, they should be extremely close.
+        assert diff.mean() < 2e-2
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_text2img_pipeline(self):