From 48ce118d1ccbd73788f2b2b75c83d9abadeda4bd Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 23 Oct 2023 23:41:52 +0530
Subject: [PATCH] [torch.compile] fix graph break problems partially (#5453)

* fix: controlnet graph?
* fix: sample
* fix:
* remove print
* styling
* fix-copies
* prevent more graph breaks?
* prevent more graph breaks?
* see?
* revert.
* compilation.
* rpopagate changes to controlnet sdxl pipeline too.
* add: clean version checking.
---
 src/diffusers/models/controlnet.py                        |  1 -
 src/diffusers/models/unet_2d_condition.py                 |  8 +++++---
 .../pipelines/controlnet/pipeline_controlnet.py           |  9 ++++++++-
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py     |  9 ++++++++-
 .../t2i_adapter/pipeline_stable_diffusion_adapter.py      |  3 ++-
 .../pipelines/versatile_diffusion/modeling_text_unet.py   |  8 +++++---
 6 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index c0d2da9b8c..052335f6c5 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -817,7 +817,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
             scales = scales * conditioning_scale
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 6a59e5ea5c..fdd0de0a30 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -874,9 +874,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         forward_upsample_size = False
         upsample_size = None
 
-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
 
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 41df612588..6944d93312 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -35,7 +35,7 @@ from ...utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -976,8 +976,15 @@ class StableDiffusionControlNetPipeline(
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 786c208e16..d6278c4f04 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -36,7 +36,7 @@ from ...models.attention_processor import (
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 
@@ -1144,8 +1144,15 @@ class StableDiffusionXLControlNetPipeline(
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index d2b9bfb00d..103cd70952 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -814,7 +814,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 4462963a8e..d936666d61 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1084,9 +1084,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         forward_upsample_size = False
         upsample_size = None
 
-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
 
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape:
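
Usage sketch (illustrative, not part of the patch): the per-step cudagraph_mark_step_begin() call above only fires when both the UNet and the ControlNet have been wrapped with torch.compile and torch >= 2.1 is installed. A minimal setup that would exercise that path on a CUDA machine might look like the following; the model IDs and the placeholder conditioning image are only examples.

import numpy as np
import torch
from PIL import Image

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Compile both modules; "reduce-overhead" mode uses CUDA graphs, which is what
# the cudagraph_mark_step_begin() call in the denoising loop is there to support.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead")

# Placeholder conditioning image; in practice this would be a Canny edge map.
canny_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

image = pipe("a photo of a cat", image=canny_image, num_inference_steps=30).images[0]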