1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[WAN] fix recompilation issues (#11475)

* [tests] Add torch.compile() test for WanTransformer3DModel

* fix wan recompilation issues.

* style

---------

Co-authored-by: tongyu0924 <winnie920924@gmail.com>
This commit is contained in:
Sayak Paul
2025-05-01 14:29:08 +08:00
committed by GitHub
parent 06beecafc5
commit d70f8ee18b
2 changed files with 25 additions and 3 deletions

View File

@@ -202,8 +202,8 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
self.freqs = self.freqs.to(hidden_states.device)
freqs = self.freqs.split_with_sizes(
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,

View File

@@ -17,7 +17,14 @@ import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
@@ -79,3 +86,18 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)