Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[torch.compile] fix graph break problems partially (#5453)
* fix: controlnet graph?
* fix: sample
* fix:
* remove print
* styling
* fix-copies
* prevent more graph breaks?
* prevent more graph breaks?
* see?
* revert.
* compilation.
* propagate changes to controlnet sdxl pipeline too.
* add: clean version checking.
@@ -817,7 +817,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
             scales = scales * conditioning_scale
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
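For context, guess mode scales the ControlNet residuals with logarithmically spaced factors so that deeper blocks contribute more strongly. A minimal standalone sketch of that scaling on toy tensors (the shapes and residual count here are illustrative, not the model's real ones):

import torch

# Toy stand-ins for 12 down-block residuals plus the mid-block residual.
down_block_res_samples = [torch.ones(2, 2) for _ in range(12)]
mid_block_res_sample = torch.ones(2, 2)
conditioning_scale = 1.0

# 13 factors spaced logarithmically from 0.1 to 1.0, one per residual.
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)
scales = scales * conditioning_scale

down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1]  # deepest residual keeps full strength

print(round(scales[0].item(), 3), round(scales[-1].item(), 3))  # 0.1 1.0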
@@ -874,9 +874,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape:
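The motivation for this rewrite: under torch.compile, the generator expression inside any() could trigger a graph break, while a plain loop with an early break traces cleanly. A standalone sketch showing the two forms are equivalent (the latent shape and up factor below are toy values):

import torch

sample = torch.randn(1, 4, 63, 64)   # toy latent; height 63 is not divisible by 8
default_overall_up_factor = 2 ** 3   # e.g. three 2x upsampling stages

# Before: generator expression inside any()
forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])

# After: explicit loop with break, same result
forward_upsample_size_loop = False
for dim in sample.shape[-2:]:
    if dim % default_overall_up_factor != 0:
        forward_upsample_size_loop = True
        break

assert forward_upsample_size == forward_upsample_size_loop  # both True here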
@@ -35,7 +35,7 @@ from ...utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
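is_torch_version compares the installed PyTorch build against a target version using the given operator string. A quick usage sketch (assuming a diffusers version where torch_utils re-exports it, as in this diff):

from diffusers.utils.torch_utils import is_torch_version

# Mirrors the gate added in the pipelines: only mark cudagraph steps on torch >= 2.1.
if is_torch_version(">=", "2.1"):
    print("torch >= 2.1: cudagraph step marking is available")
else:
    print("older torch: skipping cudagraph step marking")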
@@ -976,8 +976,15 @@ class StableDiffusionControlNetPipeline(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
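The reason for the cudagraph_mark_step_begin() call: when both the UNet and the ControlNet are compiled with mode="reduce-overhead", inductor replays CUDA graphs backed by static buffers, and each new denoising iteration must be announced so the previous step's outputs are not overwritten while still in use. A minimal sketch of the same pattern on a toy module (assumes a CUDA device and torch >= 2.1):

import torch
import torch._inductor  # make the cudagraph helper available before the first compiled call

model = torch.nn.Linear(64, 64).cuda()
compiled = torch.compile(model, mode="reduce-overhead")  # reduce-overhead uses CUDA graphs

x = torch.randn(8, 64, device="cuda")
for step in range(5):
    # Announce a new iteration to the cudagraph machinery, as the pipelines above do.
    torch._inductor.cudagraph_mark_step_begin()
    x = compiled(x)

Without the marker, feeding a cudagraph-managed output straight back in as the next input risks reading a buffer that the replay is about to overwrite.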
@@ -36,7 +36,7 @@ from ...models.attention_processor import (
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -1144,8 +1144,15 @@ class StableDiffusionXLControlNetPipeline(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -814,7 +814,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
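Switching from .sample to return_dict=False is another graph-break avoidance: constructing the ModelOutput dataclass and reading its field could break the graph under dynamo, whereas a plain tuple is handled natively. A toy sketch of the two calling conventions on a small unconditional UNet (the config values are arbitrary, chosen only to keep the model tiny; UNet2DConditionModel follows the same convention):

import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=8,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D"),
)
unet.eval()

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    out_field = unet(x, timestep=0).sample                  # dataclass output, field access
    out_tuple = unet(x, timestep=0, return_dict=False)[0]   # plain tuple, index access

assert torch.allclose(out_field, out_tuple)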
@@ -1084,9 +1084,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         forward_upsample_size = False
         upsample_size = None

-        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            # Forward upsample size to force interpolation output size.
-            forward_upsample_size = True
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         # expects mask of shape: