From ec9bfa9e148b7764137dd92247ce859d915abcb0 Mon Sep 17 00:00:00 2001 From: skotapati Date: Thu, 12 Dec 2024 18:05:59 -0800 Subject: [PATCH] Remove mps workaround for fp16 GELU, which is now supported natively (#10133) * Remove mps workaround for fp16 GELU, which is now supported natively --------- Co-authored-by: hlky --- src/diffusers/models/activations.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f4318fc3cd..c1d4f0b46e 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn from ..utils import deprecate -from ..utils.import_utils import is_torch_npu_available +from ..utils.import_utils import is_torch_npu_available, is_torch_version if is_torch_npu_available(): @@ -79,10 +79,10 @@ class GELU(nn.Module): self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + return F.gelu(gate, approximate=self.approximate) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) @@ -105,10 +105,10 @@ class GEGLU(nn.Module): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + return F.gelu(gate) def forward(self, hidden_states, *args, **kwargs): if len(args) > 0 or kwargs.get("scale", None) is not None: