mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Removing autocast for 35-25% speedup. (autocast considered harmful). (#511)
* Removing `autocast` for `35-25% speedup`. * iQuality * Adding a slow test. * Fixing mps noise generation. * Raising error on wrong device, instead of just casting on behalf of user. * Quality. * fix merge Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
This commit is contained in:
29
README.md
29
README.md
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).images[0]
|
||||
```
|
||||
|
||||
**Note**: If you don't want to use the token, you can also simply download the model weights
|
||||
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).images[0]
|
||||
```
|
||||
|
||||
If you are limited by GPU memory, you might want to consider using the model in `fp16` as
|
||||
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
pipe.enable_attention_slicing()
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).images[0]
|
||||
```
|
||||
|
||||
Finally, if you wish to use a different scheduler, you can simply instantiate
|
||||
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).images[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
|
||||
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
|
||||
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
|
||||
|
||||
```python
|
||||
from torch import autocast
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
with autocast("cuda"):
|
||||
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
|
||||
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
|
||||
|
||||
images[0].save("fantasy_landscape.png")
|
||||
```
|
||||
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
|
||||
```python
|
||||
from io import BytesIO
|
||||
|
||||
from torch import autocast
|
||||
import torch
|
||||
import requests
|
||||
import PIL
|
||||
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
pipe = pipe.to(device)
|
||||
|
||||
prompt = "a cat sitting on a bench"
|
||||
with autocast("cuda"):
|
||||
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
|
||||
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
|
||||
|
||||
images[0].save("cat_on_bench.png")
|
||||
```
|
||||
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
|
||||
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
```python
|
||||
# !pip install diffusers transformers
|
||||
from torch import autocast
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
device = "cuda"
|
||||
@@ -270,8 +260,7 @@ ldm = ldm.to(device)
|
||||
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
with autocast(device):
|
||||
image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
|
||||
image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
|
||||
|
||||
# save image
|
||||
image.save("squirrel.png")
|
||||
@@ -279,7 +268,6 @@ image.save("squirrel.png")
|
||||
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
||||
```python
|
||||
# !pip install diffusers
|
||||
from torch import autocast
|
||||
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
||||
|
||||
model_id = "google/ddpm-celebahq-256"
|
||||
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
|
||||
ddpm.to(device)
|
||||
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
with autocast("cuda"):
|
||||
image = ddpm().images[0]
|
||||
image = ddpm().images[0]
|
||||
|
||||
# save image
|
||||
image.save("ddpm_generated_image.png")
|
||||
|
||||
@@ -266,7 +266,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
emb = self.time_embedding(t_emb.to(self.dtype))
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).images[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
|
||||
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
|
||||
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
|
||||
|
||||
```python
|
||||
from torch import autocast
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
with autocast("cuda"):
|
||||
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
|
||||
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
|
||||
|
||||
images[0].save("fantasy_landscape.png")
|
||||
```
|
||||
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
|
||||
```python
|
||||
from io import BytesIO
|
||||
|
||||
from torch import autocast
|
||||
import requests
|
||||
import PIL
|
||||
|
||||
@@ -173,8 +168,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
).to(device)
|
||||
|
||||
prompt = "a cat sitting on a bench"
|
||||
with autocast("cuda"):
|
||||
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
|
||||
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
|
||||
|
||||
images[0].save("cat_on_bench.png")
|
||||
```
|
||||
|
||||
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).sample[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
|
||||
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
||||
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
||||
).to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).sample[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
|
||||
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler(
|
||||
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
||||
).to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt).images[0]
|
||||
image = pipe(prompt).sample[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
|
||||
|
||||
@@ -260,19 +260,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_device = "cpu" if self.device.type == "mps" else self.device
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device=latents_device,
|
||||
dtype=text_embeddings.dtype,
|
||||
)
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(latents_device)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -1214,6 +1214,37 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
assert mem_bytes > 3.75 * 10**9
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
|
||||
).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output_chunked = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image_chunked = output_chunked.images
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# Make sure results are close enough
|
||||
diff = np.abs(image_chunked.flatten() - image.flatten())
|
||||
# They ARE different since ops are not run always at the same precision
|
||||
# however, they should be extremely close.
|
||||
assert diff.mean() < 2e-2
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
def test_stable_diffusion_text2img_pipeline(self):
|
||||
|
||||
Reference in New Issue
Block a user