diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index f1e977f0db..6af3a5776a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -47,6 +47,7 @@ from diffusers.models.attention_processor import ( XFormersAttnProcessor, ) from diffusers.models.auto_model import AutoModel +from diffusers.models.modeling_outputs import BaseOutput from diffusers.training_utils import EMAModel from diffusers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -109,7 +110,7 @@ def check_if_lora_correctly_set(model) -> bool: def normalize_output(out): - out0 = out[0] if isinstance(out, tuple) else out + out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out return torch.stack(out0) if isinstance(out0, list) else out0