mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[WAN] fix recompilation issues (#11475)
* [tests] Add torch.compile() test for WanTransformer3DModel * fix wan recompilation issues. * style --------- Co-authored-by: tongyu0924 <winnie920924@gmail.com>
This commit is contained in:
@@ -202,8 +202,8 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
self.freqs = self.freqs.to(hidden_states.device)
|
||||
freqs = self.freqs.split_with_sizes(
|
||||
freqs = self.freqs.to(hidden_states.device)
|
||||
freqs = freqs.split_with_sizes(
|
||||
[
|
||||
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
|
||||
self.attention_head_dim // 6,
|
||||
|
||||
@@ -17,7 +17,14 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
@@ -79,3 +86,18 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
Reference in New Issue
Block a user