mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (#7244)
update
This commit is contained in:
@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils import is_torch_version, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
device = self._execution_device
|
||||
dtype = self.decoder.dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
|
||||
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
|
||||
Reference in New Issue
Block a user