mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] Add torch.compile test for UNet2DConditionModel (#11537)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -53,7 +53,12 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
UNetTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
@@ -351,7 +356,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(
|
||||
ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
|
||||
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
|
||||
):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
Reference in New Issue
Block a user