mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Torch Compile] Fix torch compile for svd vae (#6217)
This commit is contained in:
committed by
GitHub
parent
cce1fe2d41
commit
8d891e6e1b
@@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
||||
from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
@@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
|
||||
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
||||
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
||||
|
||||
# decode decode_chunk_size frames at a time to avoid OOM
|
||||
frames = []
|
||||
|
||||
Reference in New Issue
Block a user