
[Torch 2.0 compile] Fix more torch compile breaks (#3313)

* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>
Author: Patrick von Platen
Date: 2023-05-02 19:51:00 +02:00
Committed by: Daniel Gu
Parent: 863bb75ea9
Commit: c8cc4f01ce
22 changed files with 219 additions and 78 deletions
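Two patterns recur across the hunks below: in-place tensor and tuple updates (`+=`, `*=`) are rewritten as out-of-place assignments, and UNet/VAE/scheduler calls pass `return_dict=False` and index the returned tuple instead of reading `.sample` / `.prev_sample` off an output class; both changes avoided graph breaks under `torch.compile` in torch 2.0. A minimal runnable sketch of the in-place rewrite, with an illustrative function that is not part of the diff:

import torch

def denoise_step(x: torch.Tensor, residual: torch.Tensor, scale: float) -> torch.Tensor:
    # Out-of-place updates, in the style of the diff's `sample = sample + controlnet_cond`;
    # the old in-place `x += residual` could break fullgraph compilation.
    x = x + residual
    return x * scale

# fullgraph=True asks Dynamo to capture the whole function without graph breaks.
compiled_step = torch.compile(denoise_step, fullgraph=True)
out = compiled_step(torch.randn(2, 4), torch.randn(2, 4), 0.5)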


@@ -498,7 +498,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
@@ -517,7 +517,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
- sample += controlnet_cond
+ sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
@@ -551,7 +551,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples += (down_block_res_sample,)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
@@ -559,13 +559,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# 6. scaling
if guess_mode:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
- scales *= conditioning_scale
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample *= scales[-1] # last one
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample *= conditioning_scale
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
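Besides going out-of-place, the guess-mode hunk above also allocates the scale ramp directly on the sample's device, so the compiled graph never has to migrate a CPU tensor mid-step. A sketch of the resulting shape, using a hypothetical helper name:

import torch

def guess_mode_scales(num_res_samples: int, conditioning_scale: float, device: torch.device) -> torch.Tensor:
    # 0.1 to 1.0 ramp, created on the target device as in the diff's
    # `torch.logspace(..., device=sample.device)` change.
    scales = torch.logspace(-1, 0, num_res_samples + 1, device=device)
    return scales * conditioning_scale  # out-of-place, unlike the old `scales *= ...`

print(guess_mode_scales(12, 1.0, torch.device("cpu")))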


@@ -740,7 +740,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
- new_down_block_res_samples += (down_block_res_sample,)
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples


@@ -457,7 +457,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
- image = self.vae.decode(latents).sample
+ image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -728,7 +728,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -736,7 +737,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -745,7 +746,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
callback(i, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -918,7 +918,8 @@ class IFImg2ImgPipeline(DiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -930,8 +931,8 @@ class IFImg2ImgPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -1036,7 +1036,8 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -1048,8 +1049,8 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -1033,7 +1033,8 @@ class IFInpaintingPipeline(DiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -1047,8 +1048,8 @@ class IFInpaintingPipeline(DiffusionPipeline):
prev_intermediate_images = intermediate_images
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images


@@ -1143,7 +1143,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -1157,8 +1158,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
prev_intermediate_images = intermediate_images
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images


@@ -886,7 +886,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -898,8 +899,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
-     noise_pred, t, intermediate_images, **extra_step_kwargs
- ).prev_sample
+     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
+ import torch.nn.functional as F
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -579,9 +580,20 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
)
# Check `image`
- if isinstance(self.controlnet, ControlNetModel):
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+     self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+     isinstance(self.controlnet, ControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
self.check_image(image, prompt, prompt_embeds)
- elif isinstance(self.controlnet, MultiControlNetModel):
+ elif (
+     isinstance(self.controlnet, MultiControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
@@ -600,10 +612,18 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
assert False
# Check `controlnet_conditioning_scale`
- if isinstance(self.controlnet, ControlNetModel):
+ if (
+     isinstance(self.controlnet, ControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
- elif isinstance(self.controlnet, MultiControlNetModel):
+ elif (
+     isinstance(self.controlnet, MultiControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
@@ -910,7 +930,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
)
# 4. Prepare image
- if isinstance(self.controlnet, ControlNetModel):
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+     self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+     isinstance(self.controlnet, ControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
image = self.prepare_image(
image=image,
width=width,
@@ -922,7 +949,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
- elif isinstance(self.controlnet, MultiControlNetModel):
+ elif (
+     isinstance(self.controlnet, MultiControlNetModel)
+     or is_compiled
+     and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
images = []
for image_ in image:
@@ -1006,7 +1037,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -1014,7 +1046,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
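The new guards above are needed because `torch.compile` returns a `torch._dynamo.eval_frame.OptimizedModule` wrapper, so a compiled ControlNet no longer satisfies `isinstance(self.controlnet, ControlNetModel)`; the original module stays reachable via `_orig_mod`. Note the conditions rely on `and` binding tighter than `or`, i.e. they parse as `isinstance(...) or (is_compiled and isinstance(..._orig_mod, ...))`. A minimal sketch of the unwrapping idea, with a hypothetical helper:

import torch
from torch import nn

def underlying_module(module: nn.Module) -> nn.Module:
    # torch.compile keeps the original module on `_orig_mod`; this mirrors
    # the type checks the pipeline now performs.
    if isinstance(module, torch._dynamo.eval_frame.OptimizedModule):
        return module._orig_mod
    return module

net = torch.compile(nn.Linear(4, 4))
assert isinstance(underlying_module(net), nn.Linear)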


@@ -677,7 +677,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]
# perform guidance
if do_classifier_free_guidance:
@@ -685,7 +687,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -462,7 +462,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
- image = self.vae.decode(latents).sample
+ image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -734,7 +734,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -742,7 +743,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -751,7 +752,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
callback(i, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -878,7 +878,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]
# perform guidance
if do_classifier_free_guidance:
@@ -886,7 +888,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -690,7 +690,9 @@ class StableDiffusionInpaintPipelineLegacy(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+     0
+ ]
# perform guidance
if do_classifier_free_guidance:
@@ -698,7 +700,7 @@ class StableDiffusionInpaintPipelineLegacy(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# masking
if add_predicted_noise:
init_latents_proper = self.scheduler.add_noise(


@@ -346,7 +346,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
# predict the noise residual
- noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred = self.unet(
+     scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
+ )[0]
# Hack:
# For karras style schedulers the model does classifer free guidance using the
@@ -376,7 +378,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
noise_pred = (noise_pred - latents) / (-sigma)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -678,8 +678,12 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# predict the noise residual
noise_pred = self.unet(
-     latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level
- ).sample
+     latent_model_input,
+     t,
+     encoder_hidden_states=prompt_embeds,
+     class_labels=noise_level,
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -687,7 +691,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):


@@ -830,7 +830,8 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
timestep=t,
sample=prior_latents,
**prior_extra_step_kwargs,
- ).prev_sample
+     return_dict=False,
+ )[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, prior_latents)
@@ -903,7 +904,8 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
encoder_hidden_states=prompt_embeds,
class_labels=image_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -911,7 +913,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)


@@ -799,7 +799,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
encoder_hidden_states=prompt_embeds,
class_labels=image_embeds,
cross_attention_kwargs=cross_attention_kwargs,
- ).sample
+     return_dict=False,
+ )[0]
# perform guidance
if do_classifier_free_guidance:
@@ -807,7 +808,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)


@@ -843,7 +843,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
- new_down_block_res_samples += (down_block_res_sample,)
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples


@@ -866,6 +866,28 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
max_diff = np.abs(expected_image - image).max()
assert max_diff < 5e-2
+ def test_stable_diffusion_compile(self):
+     if version.parse(torch.__version__) < version.parse("2.0"):
+         print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+         return
+     sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
+     sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
+     sd_pipe = sd_pipe.to(torch_device)
+     sd_pipe.unet.to(memory_format=torch.channels_last)
+     sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
+     sd_pipe.set_progress_bar_config(disable=None)
+     inputs = self.get_inputs(torch_device)
+     image = sd_pipe(**inputs).images
+     image_slice = image[0, -3:, -3:, -1].flatten()
+     assert image.shape == (1, 512, 512, 3)
+     expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
+     assert np.abs(image_slice - expected_slice).max() < 5e-3
@slow
@require_torch_gpu
@@ -922,28 +944,6 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
assert np.max(np.abs(image - image_ckpt)) < 1e-4
- def test_stable_diffusion_compile(self):
-     if version.parse(torch.__version__) >= version.parse("2.0"):
-         print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
-         return
-     sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
-     sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
-     sd_pipe = sd_pipe.to(torch_device)
-     sd_pipe.unet.to(memory_format=torch.channels_last)
-     sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
-     sd_pipe.set_progress_bar_config(disable=None)
-     inputs = self.get_inputs(torch_device)
-     image = sd_pipe(**inputs).images
-     image_slice = image[0, -3:, -3:, -1].flatten()
-     assert image.shape == (1, 512, 512, 3)
-     expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
-     assert np.abs(image_slice - expected_slice).max() < 1e-4
@nightly
@require_torch_gpu
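The relocated slow test above doubles as the recommended usage recipe: put the UNet in channels-last memory format, then wrap it with `torch.compile` before running the pipeline. The same steps outside the test harness (requires torch >= 2.0 and a CUDA GPU; model IDs taken from the test):

import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# The first call triggers compilation; subsequent calls reuse the compiled graph.
image = pipe("a photograph of an astronaut riding a horse").images[0]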


@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
+ from packaging import version
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -585,6 +586,42 @@ class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_compile(self):
+     if version.parse(torch.__version__) < version.parse("2.0"):
+         print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+         return
+     controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+     pipe = StableDiffusionControlNetPipeline.from_pretrained(
+         "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+     )
+     pipe.to("cuda")
+     pipe.set_progress_bar_config(disable=None)
+     pipe.unet.to(memory_format=torch.channels_last)
+     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+     pipe.controlnet.to(memory_format=torch.channels_last)
+     pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+     generator = torch.Generator(device="cpu").manual_seed(0)
+     prompt = "bird"
+     image = load_image(
+         "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+     )
+     output = pipe(prompt, image, generator=generator, output_type="np")
+     image = output.images[0]
+     assert image.shape == (768, 512, 3)
+     expected_image = load_numpy(
+         "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
+     )
+     assert np.abs(expected_image - image).max() < 1e-1
@slow
@require_torch_gpu


@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
+ from packaging import version
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -460,6 +461,28 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
+ def test_img2img_compile(self):
+     if version.parse(torch.__version__) < version.parse("2.0"):
+         print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+         return
+     pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
+     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+     pipe.to(torch_device)
+     pipe.set_progress_bar_config(disable=None)
+     pipe.unet.to(memory_format=torch.channels_last)
+     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+     inputs = self.get_inputs(torch_device)
+     image = pipe(**inputs).images
+     image_slice = image[0, -3:, -3:, -1].flatten()
+     assert image.shape == (1, 512, 768, 3)
+     expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])
+     assert np.abs(expected_slice - image_slice).max() < 1e-3
@nightly
@require_torch_gpu


@@ -19,6 +19,7 @@ import unittest
import numpy as np
import torch
+ from packaging import version
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -274,6 +275,31 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
+ def test_inpaint_compile(self):
+     if version.parse(torch.__version__) < version.parse("2.0"):
+         print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+         return
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting", safety_checker=None
+     )
+     pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
+     pipe.to(torch_device)
+     pipe.set_progress_bar_config(disable=None)
+     pipe.unet.to(memory_format=torch.channels_last)
+     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+     inputs = self.get_inputs(torch_device)
+     image = pipe(**inputs).images
+     image_slice = image[0, 253:256, 253:256, -1].flatten()
+     assert image.shape == (1, 512, 512, 3)
+     expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
-     assert np.abs(expected_slice - image_slice).max() < 1e-4
+     assert np.abs(expected_slice - image_slice).max() < 1e-3
@nightly
@require_torch_gpu