From 5fb3a985173efaae7ff381b9040c386751d643da Mon Sep 17 00:00:00 2001
From: fancy45daddy <124528204+fancy45daddy@users.noreply.github.com>
Date: Mon, 16 Dec 2024 01:05:50 -0800
Subject: [PATCH] Update pipeline_controlnet.py add support for pytorch_xla
 (#10222)

* Update pipeline_controlnet.py

* make style

---------

Co-authored-by: hlky
---
 .../pipelines/controlnet/pipeline_controlnet.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 486f9fb764..582f51ab48 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -31,6 +31,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -1323,6 +1331,8 @@ class StableDiffusionControlNetPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
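
For context, the patch guards the `torch_xla` import behind `is_torch_xla_available()` and, when XLA is present, calls `xm.mark_step()` at the end of each denoising iteration. Below is a minimal usage sketch, not part of the patch itself, showing how the pipeline might be run on an XLA device so the new barrier takes effect; the checkpoint names, dtype, and blank control image are illustrative assumptions, and `torch_xla` with an attached TPU/XLA device is assumed to be installed.

```python
# Minimal sketch (assumptions noted above): run StableDiffusionControlNetPipeline
# on an XLA device so the patch's per-step xm.mark_step() is exercised.
import torch
import torch_xla.core.xla_model as xm
from PIL import Image

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

device = xm.xla_device()  # acquire the XLA device, e.g. a TPU core

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.bfloat16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.bfloat16
).to(device)

# Placeholder control image for illustration; in practice this would be a
# Canny edge map matching the checkpoint's conditioning.
control_image = Image.new("RGB", (512, 512))

# With the patch applied, every denoising iteration ends in xm.mark_step(),
# which cuts the lazily traced XLA graph and dispatches it for execution.
image = pipe(
    "a photo of a cat", image=control_image, num_inference_steps=20
).images[0]
```

The design choice here is the usual one for lazy-tensor backends: without a per-iteration `xm.mark_step()`, XLA tracing would accumulate the entire sampling loop into a single large graph, inflating compile time and device memory, whereas marking each step keeps the traced graph small and reusable across iterations.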