Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Allow fp16 attn for x4 upscaler (#3239)
* Add all files
* update
* Make sure vae is memory efficient for PT 1
* make style
Committed by GitHub
parent da2ce1a6b9
commit abbf3c1adf
src/diffusers/models/vae.py

@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)
 
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
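The new line probes `next(iter(self.up_blocks.parameters())).dtype`, i.e. the dtype of the first parameter of the up path, so the decoder knows what to cast the activation back to later. A minimal sketch of that probe, using toy modules rather than the diffusers classes, and bfloat16 in place of fp16 so it also runs on CPU:

import torch
import torch.nn as nn

mid = nn.Linear(8, 8).to(torch.bfloat16)  # low-precision block
up = nn.Linear(8, 8)                      # float32 block

# Same probe as in the diff: dtype of the up path's first parameter.
upscale_dtype = next(iter(up.parameters())).dtype
print(upscale_dtype)  # torch.float32

x = torch.randn(2, 8, dtype=torch.bfloat16)
x = mid(x)               # runs in low precision
x = x.to(upscale_dtype)  # bridge dtypes before entering the up path
x = up(x)
print(x.dtype)  # torch.float32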
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
 
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
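The surrounding context uses the standard `create_custom_forward` trick so that `torch.utils.checkpoint.checkpoint`, which expects a callable rather than a module, can recompute a block's activations during backward. A self-contained sketch of that pattern with a toy block:

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

def create_custom_forward(module):
    # checkpoint() wants a function, so wrap the module call.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(4, 16, requires_grad=True)
# Activations inside `block` are recomputed in backward instead of stored.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])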
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
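Taken together, the three decoder hunks implement a split-precision forward pass: everything up to and including `mid_block` may run in fp16, and the activation is upcast exactly once before the float32 `up_blocks`. A toy version of that control flow (hypothetical class, not the diffusers `Decoder`; bfloat16 stands in for fp16 so the sketch runs on CPU):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Linear(8, 8)
        self.mid_block = nn.Linear(8, 8)
        self.up_blocks = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])

    def forward(self, z):
        sample = self.conv_in(z)
        # Cast target comes from the up path, exactly as in the diff.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        sample = self.mid_block(sample)
        sample = sample.to(upscale_dtype)  # bridge low precision to fp32
        for up_block in self.up_blocks:
            sample = up_block(sample)
        return sample

decoder = ToyDecoder()
# Mimic the pipeline: early blocks in low precision, up path in fp32.
decoder.conv_in.to(torch.bfloat16)
decoder.mid_block.to(torch.bfloat16)
out = decoder(torch.randn(2, 8, dtype=torch.bfloat16))
print(out.dtype)  # torch.float32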
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Union
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
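The new `import torch.nn.functional as F` exists only to support a capability check: PyTorch 2.0 added `F.scaled_dot_product_attention`, so its presence distinguishes PT 2.x from PT 1.x at runtime without parsing version strings. The check in isolation:

import torch.nn.functional as F

# True on PyTorch >= 2.0, False on 1.x.
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
print(use_torch_2_0_attn)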
@@ -698,10 +699,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_attn or use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
         # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
 
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             image = self.numpy_to_pil(image)
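The branch added above decides whether the latents may stay in fp16 at decode time: with an efficient attention backend, `post_quant_conv`, `conv_in`, and `mid_block` are cast back to the latent dtype, and the decoder hunks handle the upcast before the up blocks; otherwise everything falls back to float32 as before. From user code, eligibility is just a matter of enabling one of the two backends. A typical invocation (standard diffusers API; the model id is the official x4 upscaler checkpoint):

import torch
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
# Either of these makes the VAE attention eligible for fp16 after this commit:
pipe.enable_xformers_memory_efficient_attention()  # or run on PyTorch 2.0

# low_res_img: a PIL.Image to be 4x-upscaled, e.g. loaded with PIL.Image.open
# upscaled = pipe(prompt="a white cat", image=low_res_img).images[0]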
@@ -710,11 +723,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU
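The `pt` branch in the last hunk exposes the raw decode path: latents are unscaled by the VAE's configured `scaling_factor` and fed straight to `vae.decode`, and after this commit the `.float()` upcast is no longer forced there. A minimal sketch of that step, reusing `pipe` from the sketch above; the latent shape is an illustrative assumption for the x4 upscaler, not taken from the commit:

import torch

# Assumed latent shape: (batch, 4, H, W) at the low-res image resolution.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")

# Undo the scaling applied at encode time, then decode to image space.
latents = 1 / pipe.vae.config.scaling_factor * latents
image = pipe.vae.decode(latents).sample  # approx. (1, 3, 512, 512)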