From 39dfb7abbdbd3aea827cf228d1e00a80fc5cea80 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 7 Mar 2024 17:55:46 +0530 Subject: [PATCH] Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (#7244) update --- .../pipelines/stable_cascade/pipeline_stable_cascade.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 583ef3ee02..8a8d5b65e3 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_version, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel @@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): device = self._execution_device dtype = self.decoder.dtype self._guidance_scale = guidance_scale + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") # 1. Check inputs. Raise error if not correct self.check_inputs(