From 8d891e6e1bc02fb42a891d95cfa8a315dadb3b5a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 18 Dec 2023 18:21:17 +0100 Subject: [PATCH] [Torch Compile] Fix torch compile for svd vae (#6217) --- .../pipeline_stable_video_diffusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index a82f5379e7..988623ca65 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): latents = 1 / self.vae.config.scaling_factor * latents - accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) # decode decode_chunk_size frames at a time to avoid OOM frames = []