
Allow fp16 attn for x4 upscaler (#3239)

* Add all files

* update

* Make sure vae is memory efficient for PT 1

* make style
Authored by Patrick von Platen on 2023-04-26 12:16:06 +02:00, committed by GitHub
parent da2ce1a6b9
commit abbf3c1adf
2 changed files with 20 additions and 4 deletions
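
For orientation, a minimal usage sketch of the path this commit touches: loading the x4 upscaler in half precision and decoding with memory-efficient attention. This is not part of the change itself; the image URL, prompt, and output file name below are placeholders, and "stabilityai/stable-diffusion-x4-upscaler" is the public x4 upscaler checkpoint.

import torch
from diffusers import StableDiffusionUpscalePipeline
from diffusers.utils import load_image

# Load the x4 upscaler in fp16 (assumes a CUDA device is available).
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

# On PyTorch 1.x, xformers can be enabled so the VAE attention may stay in fp16:
# pipe.enable_xformers_memory_efficient_attention()

low_res = load_image("https://example.com/low_res_cat.png")  # placeholder URL
upscaled = pipe(prompt="a white cat", image=low_res).images[0]
upscaled.save("upscaled_cat.png")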


@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)

+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:

             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):

             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
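
The Decoder change above keeps the early, overflow-prone part of the VAE in fp32 while the up blocks stay in fp16: the activation is cast to the up blocks' parameter dtype before it is handed off. A standalone sketch of that pattern, using toy modules rather than the actual diffusers classes:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device == "cuda" else torch.float32  # use fp16 only where it is well supported

# Toy stand-ins for conv_in/mid_block (kept in fp32) and up_blocks (lower precision).
conv_in = nn.Conv2d(4, 64, 3, padding=1).to(device, torch.float32)
mid_block = nn.Conv2d(64, 64, 3, padding=1).to(device, torch.float32)
up_blocks = nn.Sequential(nn.Conv2d(64, 3, 3, padding=1)).to(device, low_dtype)

z = torch.randn(1, 4, 64, 64, device=device)

# Same trick as in Decoder.forward: read the target dtype off the up blocks' parameters.
upscale_dtype = next(iter(up_blocks.parameters())).dtype

sample = mid_block(conv_in(z.float()))
sample = sample.to(upscale_dtype)  # hand the activation to the lower-precision blocks
sample = up_blocks(sample)
print(sample.dtype)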


@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Union
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...loaders import TextualInversionLoaderMixin
@@ -698,10 +699,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)

+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
         # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
             image = self.numpy_to_pil(image)
@@ -710,11 +723,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None

         # Offload last model to CPU
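
As a follow-up note on the pipeline change: the fp16 decode path is only taken when memory-efficient attention is available, which the pipeline detects automatically on PyTorch 2.0; on PyTorch 1.x it has to be opted into via xformers. A hedged sketch of that opt-in; the helper name below is hypothetical, not a diffusers API.

import torch.nn.functional as F
from diffusers import DiffusionPipeline

def prefer_memory_efficient_attention(pipe: DiffusionPipeline) -> None:
    # PyTorch 2.0 ships F.scaled_dot_product_attention, which the pipeline detects
    # on its own; on PyTorch 1.x, enabling xformers (if installed) is what lets the
    # VAE attention stay in fp16 instead of falling back to the fp32 path.
    if not hasattr(F, "scaled_dot_product_attention"):
        pipe.enable_xformers_memory_efficient_attention()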