From 75a636da4882771ca8834b804f767daa9394ffa8 Mon Sep 17 00:00:00 2001 From: baymax591 Date: Tue, 21 Jan 2025 03:35:24 +0800 Subject: [PATCH] bugfix for npu not support float64 (#10123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix for npu not support float64 * is_mps is_npu --------- Co-authored-by: 白超 Co-authored-by: hlky --- examples/community/fresco_v2v.py | 5 +++-- examples/community/matryoshka.py | 5 +++-- .../pixart/pipeline_pixart_alpha_controlnet.py | 5 +++-- .../promptdiffusion/promptdiffusioncontrolnet.py | 5 +++-- src/diffusers/models/controlnets/controlnet.py | 5 +++-- src/diffusers/models/controlnets/controlnet_sparsectrl.py | 5 +++-- src/diffusers/models/controlnets/controlnet_union.py | 5 +++-- src/diffusers/models/controlnets/controlnet_xs.py | 5 +++-- src/diffusers/models/unets/unet_2d_condition.py | 5 +++-- src/diffusers/models/unets/unet_3d_condition.py | 5 +++-- src/diffusers/models/unets/unet_i2vgen_xl.py | 5 +++-- src/diffusers/models/unets/unet_motion_model.py | 5 +++-- src/diffusers/models/unets/unet_spatio_temporal_condition.py | 5 +++-- src/diffusers/pipelines/audioldm2/modeling_audioldm2.py | 5 +++-- .../deprecated/versatile_diffusion/modeling_text_unet.py | 5 +++-- src/diffusers/pipelines/dit/pipeline_dit.py | 5 +++-- src/diffusers/pipelines/latte/pipeline_latte.py | 5 +++-- src/diffusers/pipelines/lumina/pipeline_lumina.py | 5 +++-- src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py | 5 +++-- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 5 +++-- .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 5 +++-- 21 files changed, 63 insertions(+), 42 deletions(-) diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py index 2784e2f238..d6c2683f1d 100644 --- a/examples/community/fresco_v2v.py +++ b/examples/community/fresco_v2v.py @@ -404,10 +404,11 @@ def my_forward( # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index f80b29456c..1d7a367ecc 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2806,10 +2806,11 @@ class MatryoshkaUNet2DConditionModel( # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py index d7f882974a..4065a854c2 100644 --- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -1031,10 +1031,11 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py index 6b1826a1c9..7853695f05 100644 --- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py +++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py @@ -258,10 +258,11 @@ class PromptDiffusionControlNetModel(ControlNetModel): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index bd00f6dd19..1453aaf436 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -740,10 +740,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index fd599c10b2..807cbd339e 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -671,10 +671,11 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index fc80da7623..1bf176101c 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -681,10 +681,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 11ad676ec9..8a8901d82d 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -1088,10 +1088,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index e488f5897e..2b896f89e4 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -915,10 +915,11 @@ class UNet2DConditionModel( # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 3081fdc470..56739ac24c 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -624,10 +624,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index 6ab3a577b8..d5d98c2563 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -575,10 +575,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timesteps, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index ddc3e41c34..1c07a0760f 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2114,10 +2114,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 308b9e01c5..172c1e6bbb 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -402,10 +402,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index 63d3957ae1..a33e265687 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -768,10 +768,11 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 0fd8875a88..4d9e50e3a2 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1163,10 +1163,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py index cf5ebbce2b..8aee0fadaf 100644 --- a/src/diffusers/pipelines/dit/pipeline_dit.py +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -187,10 +187,11 @@ class DiTPipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(timesteps, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(latent_model_input.device) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 1b70650dfa..ce4ca313eb 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -798,10 +798,11 @@ class LattePipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 52bb654603..5b37e9a503 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -806,10 +806,11 @@ class LuminaText2ImgPipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor( [current_timestep], dtype=dtype, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index d927a7961a..affda7e18a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -807,10 +807,11 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 46a7337051..b550a442fe 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -907,10 +907,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index 356ba3a29a..7f10ee89ee 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -822,10 +822,11 @@ class PixArtSigmaPipeline(DiffusionPipeline): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 else: - dtype = torch.int32 if is_mps else torch.int64 + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device)