diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 52d056ae96..37ad2ba91e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -419,6 +419,8 @@ class ResnetBlock2D(nn.Module): self.nonlinearity = Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() self.upsample = self.downsample = None if self.up: