Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Torch compile graph fix (#3286)
* fix more
* Fix more
* fix more
* Apply suggestions from code review
* fix
* make style
* make fix-copies
* fix
* make sure torch compile
* Clean
* fix test
parent 709cf554f6
commit 0e82fb19e1
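The changes below all serve one goal: removing TorchDynamo graph breaks from the UNet and the pipeline hot loop so that `torch.compile(..., fullgraph=True)` can trace the model end to end. A minimal sketch of the usage this commit enables (the prompt and step count are illustrative; requires PyTorch >= 2.0 and a CUDA device):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# channels_last tends to help conv-heavy models; fullgraph=True asserts that
# tracing the UNet yields a single graph with no breaks.
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]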
@@ -18,6 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -193,6 +194,7 @@ class AttentionBlock(nn.Module):
         return hidden_states
 
 
+@maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
@@ -17,7 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
 
@@ -31,6 +31,7 @@ else:
     xformers = None
 
 
+@maybe_allow_in_graph
 class Attention(nn.Module):
     r"""
     A cross attention layer.
@@ -77,8 +77,14 @@ def get_parameter_device(parameter: torch.nn.Module):
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
-        return next(parameters_and_buffers).dtype
+        params = tuple(parameter.parameters())
+        if len(params) > 0:
+            return params[0].dtype
+
+        buffers = tuple(parameter.buffers())
+        if len(buffers) > 0:
+            return buffers[0].dtype
+
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
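For reference, a standalone version of the new dtype lookup. Calling `next()` on an `itertools.chain` generator was a graph-break point for TorchDynamo at the time, while materializing the parameters into a tuple and indexing it is traceable. A sketch (the function name `get_module_dtype` is ours, not the library's):

import torch

def get_module_dtype(module: torch.nn.Module) -> torch.dtype:
    # Prefer parameters; fall back to buffers for parameter-less modules.
    params = tuple(module.parameters())
    if len(params) > 0:
        return params[0].dtype
    buffers = tuple(module.buffers())
    if len(buffers) > 0:
        return buffers[0].dtype
    raise ValueError("module has neither parameters nor buffers")

print(get_module_dtype(torch.nn.Linear(4, 4)))  # torch.float32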
@@ -560,7 +560,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
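All the `.sample` to `return_dict=False)[0]` rewrites in this commit follow the same pattern: the dataclass-style `ModelOutput` wrapper is convenient, but constructing it can break the compiled graph, whereas a plain tuple is transparent to the tracer. A self-contained toy illustrating the two call styles (`ToyBlock` and `ToyOutput` are stand-ins, not diffusers classes):

import torch
from torch import nn

class ToyOutput:
    # Stand-in for a diffusers ModelOutput dataclass.
    def __init__(self, sample):
        self.sample = sample

class ToyBlock(nn.Module):
    def forward(self, x, return_dict=True):
        out = x * 2
        return ToyOutput(out) if return_dict else (out,)

block = ToyBlock()
x = torch.ones(2)

y1 = block(x).sample                  # dataclass-style access; wrapper can break tracing
y2 = block(x, return_dict=False)[0]   # tuple-style access; plain indexing
assert torch.equal(y1, y2)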
@@ -868,15 +869,16 @@ class CrossAttnDownBlock2D(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -949,13 +951,13 @@ class DownBlock2D(nn.Module):
             else:
                 hidden_states = resnet(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
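Both forms build a new tuple (tuples are immutable, so `+=` is not an in-place mutation), but the explicit `output_states = output_states + (hidden_states,)` spelling appeared to trace more reliably under TorchDynamo at the time of this commit. A minimal illustration of the rewritten accumulation pattern:

output_states = ()
for i in range(3):
    hidden_states = i  # stand-in for a tensor
    # Explicit rebuild: unambiguous to the tracer that a fresh tuple is created.
    output_states = output_states + (hidden_states,)
print(output_states)  # (0, 1, 2)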
@@ -1342,13 +1344,13 @@ class ResnetDownsampleBlock2D(nn.Module):
             else:
                 hidden_states = resnet(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -1466,13 +1468,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                 **cross_attention_kwargs,
             )
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -1859,7 +1861,8 @@ class CrossAttnUpBlock2D(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -682,7 +682,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
 
@@ -697,7 +697,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 # there might be better ways to encapsulate this.
                 class_labels = class_labels.to(dtype=sample.dtype)
 
-            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
 
         if self.config.class_embeddings_concat:
             emb = torch.cat([emb, class_emb], dim=-1)
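`self.dtype` on a diffusers `ModelMixin` resolves through `get_parameter_dtype` (patched above), so reading it walks `parameters()` in Python on every call; `sample.dtype` is a plain attribute of a tensor already flowing through the graph. A runnable sketch of the pattern the hunks switch to:

import torch

sample = torch.randn(1, 4, 8, 8, dtype=torch.float16)
t_emb = torch.randn(1, 32)  # timestep embedding, always produced in float32

# Cast using the dtype of a tensor already in the graph, rather than a
# module-level property that iterates parameters() in Python.
t_emb = t_emb.to(dtype=sample.dtype)
print(t_emb.dtype)  # torch.float16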
@@ -437,7 +437,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
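The same tuple-return convention applied at the VAE boundary; this one-line change is then propagated by `make fix-copies` to every pipeline below that carries the `# Copied from ... decode_latents` marker. A sketch of the rewritten decode step in isolation (assumes network access to fetch the VAE; the random latents are illustrative):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
latents = torch.randn(1, 4, 64, 64)

latents = 1 / vae.config.scaling_factor * latents
# return_dict=False yields a plain (sample,) tuple instead of a DecoderOutput.
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)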
@@ -683,7 +683,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
 
@@ -691,7 +692,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
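Schedulers get the same treatment: `step(...)` returns a `SchedulerOutput`-style object by default, and `return_dict=False` switches it to a tuple whose first element is the previous sample. A minimal sketch with a default-configured DDIM scheduler:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn(1, 4, 64, 64)
t = scheduler.timesteps[0]

# Tuple return: index 0 is what .prev_sample used to expose.
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]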
@@ -793,7 +793,8 @@ class IFPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
 
@@ -805,8 +806,8 @@ class IFPipeline(DiffusionPipeline):
 
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
@@ -829,7 +830,7 @@ class IFPipeline(DiffusionPipeline):
 
             # 11. Apply watermark
             if self.watermarker is not None:
-                self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+                image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
         elif output_type == "pt":
             nsfw_detected = None
             watermark_detected = None
@@ -256,7 +256,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -134,7 +134,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -516,7 +516,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -440,7 +440,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -686,7 +686,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
 
@@ -694,7 +695,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -454,7 +454,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -496,7 +496,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -326,7 +326,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -648,7 +648,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -525,7 +525,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -446,7 +446,7 @@ class StableDiffusionInpaintPipelineLegacy(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -656,7 +656,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -358,7 +358,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -221,7 +221,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -385,7 +385,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -349,7 +349,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -590,7 +590,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -366,7 +366,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -619,7 +619,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
 
         def get_map_size(module, input, output):
             nonlocal map_size
-            map_size = output.sample.shape[-2:]
+            map_size = output[0].shape[-2:]
 
         with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
             with self.progress_bar(total=num_inference_steps) as progress_bar:
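The hook change follows from the mid-block attention now being called with `return_dict=False`: the `output` seen by a forward hook is a tuple, so the spatial size is read via `output[0]` rather than `output.sample`. A self-contained sketch of the hook pattern (`TupleOut` is a toy stand-in for the attention block):

import torch
from torch import nn

map_size = None

def get_map_size(module, inputs, output):
    global map_size
    # `output` is a tuple when the wrapped module returns tuple-style results.
    map_size = output[0].shape[-2:]

class TupleOut(nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return (self.inner(x),)

wrapped = TupleOut(nn.Conv2d(4, 4, 3, padding=1))
handle = wrapped.register_forward_hook(get_map_size)
wrapped(torch.randn(1, 4, 16, 16))
handle.remove()
print(map_size)  # torch.Size([16, 16])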
@@ -373,7 +373,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -475,7 +475,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -430,7 +430,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -364,7 +364,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -785,7 +785,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
 
@@ -800,7 +800,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 # there might be better ways to encapsulate this.
                 class_labels = class_labels.to(dtype=sample.dtype)
 
-            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
 
         if self.config.class_embeddings_concat:
             emb = torch.cat([emb, class_emb], dim=-1)
@@ -1081,13 +1081,13 @@ class DownBlockFlat(nn.Module):
             else:
                 hidden_states = resnet(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
 
@@ -1211,15 +1211,16 @@ class CrossAttnDownBlockFlat(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
 
@@ -1424,7 +1425,8 @@ class CrossAttnUpBlockFlat(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
 
@@ -1528,7 +1530,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -330,7 +330,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -190,7 +190,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
@@ -247,7 +247,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -101,6 +101,7 @@ if is_torch_available():
         torch_all_close,
         torch_device,
     )
+    from .torch_utils import maybe_allow_in_graph
 
 from .testing_utils import export_to_video
@@ -25,6 +25,13 @@ if is_torch_available():
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+try:
+    from torch._dynamo import allow_in_graph as maybe_allow_in_graph
+except (ImportError, ModuleNotFoundError):
+
+    def maybe_allow_in_graph(cls):
+        return cls
+
 
 def randn_tensor(
     shape: Union[Tuple, List],
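Usage sketch for the shim above: on PyTorch 2.x, `allow_in_graph` tells TorchDynamo to admit the decorated callable into the traced graph without tracing into it; on older versions the decorator degrades to a no-op, so the same class definition works everywhere (`MyBlock` is an illustrative module):

import torch
from torch import nn

try:
    from torch._dynamo import allow_in_graph as maybe_allow_in_graph
except (ImportError, ModuleNotFoundError):
    def maybe_allow_in_graph(cls):
        return cls

@maybe_allow_in_graph
class MyBlock(nn.Module):
    def forward(self, x):
        return x * 2

print(MyBlock()(torch.ones(2)))  # tensor([2., 2.]) on any PyTorch version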
@@ -22,6 +22,7 @@ import unittest
 import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
+from packaging import version
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
 
@@ -921,6 +922,28 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
 
         assert np.max(np.abs(image - image_ckpt)) < 1e-4
 
+    def test_stable_diffusion_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.0"):
+            print(f"Test `test_stable_diffusion_compile` is skipped because {torch.__version__} is < 2.0")
+            return
+
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
+        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+
+        sd_pipe.unet.to(memory_format=torch.channels_last)
+        sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
+        assert np.abs(image_slice - expected_slice).max() < 1e-4
+
 
 @nightly
 @require_torch_gpu