From a17832b2d96c0df9b41ce2faab5659ef46916c39 Mon Sep 17 00:00:00 2001 From: chaowenguo Date: Fri, 3 Jan 2025 08:00:02 -0800 Subject: [PATCH] add pythor_xla support for render a video (#10443) * Update rerender_a_video.py * Update rerender_a_video.py * make style --------- Co-authored-by: hlky --- examples/community/rerender_a_video.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index d9c616ab5e..cae5fcb2b9 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -30,10 +30,17 @@ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import BaseOutput, deprecate, logging +from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging from diffusers.utils.torch_utils import is_compiled_module, randn_tensor +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1100,6 +1107,9 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + return latents if mask_start_t <= mask_end_t: