diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index a30af53f77..cef63cf7e6 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -34,6 +34,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .modeling_stable_audio import StableAudioProjectionModel + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -732,7 +733,7 @@ class StableAudioPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) - + if XLA_AVAILABLE: xm.mark_step()