1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Remove mps workaround for fp16 GELU, which is now supported natively (#10133)

* Remove mps workaround for fp16 GELU, which is now supported natively

---------

Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
skotapati
2024-12-12 18:05:59 -08:00
committed by GitHub
parent bdbaea8f64
commit ec9bfa9e14

View File

@@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
from ..utils.import_utils import is_torch_npu_available
from ..utils.import_utils import is_torch_npu_available, is_torch_version
if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
# fp16 gelu not supported on mps before torch 2.0
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
return F.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
# fp16 gelu not supported on mps before torch 2.0
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
return F.gelu(gate)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None: