1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (#7244)

update
This commit is contained in:
Dhruv Nair
2024-03-07 17:55:46 +05:30
committed by GitHub
parent 196835695e
commit 39dfb7abbd

View File

@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
device = self._execution_device
dtype = self.decoder.dtype
self._guidance_scale = guidance_scale
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(