Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Support bfloat16 for Upsample2D (#9480)
* Support bfloat16 for Upsample2D
* Add test and use is_torch_version
* Resolve comments and add decorator
* Simplify require_torch_version_greater_equal decorator
* Run make style

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
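For context: on torch < 2.1, nearest-neighbor interpolation of a bfloat16 tensor fails on some backends, as tracked in pytorch/pytorch#86679. A minimal repro of the failure this commit works around (the exact error text varies by PyTorch build, so the message in the comment is indicative, not quoted from the source):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)
    # On torch < 2.1 this raises a RuntimeError along the lines of
    # "upsample_nearest2d_out_frame" not implemented for 'BFloat16';
    # on torch >= 2.1 it succeeds, which is what the version gate in
    # the diff below relies on.
    y = F.interpolate(x, scale_factor=2.0, mode="nearest")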
@@ -19,6 +19,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..utils import deprecate
+from ..utils.import_utils import is_torch_version
 from .normalization import RMSNorm
 
 
@@ -151,11 +152,10 @@ class Upsample2D(nn.Module):
         if self.use_conv_transpose:
             return self.conv(hidden_states)
 
-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
-        # https://github.com/pytorch/pytorch/issues/86679
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
+        # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(torch.float32)
 
         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@ class Upsample2D(nn.Module):
         else:
             hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
 
-        # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        # Cast back to the original dtype
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(dtype)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
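Taken together, the two hunks above implement a version-gated dtype round-trip around the interpolation call. A minimal standalone sketch of that pattern (the helper name upsample_nearest_bf16_safe is hypothetical, and packaging.version stands in for diffusers' is_torch_version helper):

    import torch
    import torch.nn.functional as F
    from packaging import version


    def upsample_nearest_bf16_safe(hidden_states: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
        """Nearest-neighbor upsampling that tolerates bfloat16 inputs on torch < 2.1."""
        dtype = hidden_states.dtype
        # bfloat16 kernels for nearest upsampling only landed in PyTorch 2.1,
        # so older versions round-trip through float32.
        needs_cast = dtype == torch.bfloat16 and version.parse(torch.__version__) < version.parse("2.1")
        if needs_cast:
            hidden_states = hidden_states.to(torch.float32)
        hidden_states = F.interpolate(hidden_states, scale_factor=scale, mode="nearest")
        if needs_cast:
            hidden_states = hidden_states.to(dtype)
        return hidden_states

Gating on the torch version rather than casting unconditionally keeps the native bfloat16 path on recent PyTorch while preserving backward compatibility.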
@@ -252,6 +252,18 @@ def require_torch_2(test_case):
     )
 
 
+def require_torch_version_greater_equal(torch_version):
+    """Decorator marking a test that requires torch with a specific version or greater."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
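Because require_torch_version_greater_equal takes a parameter, it is a decorator factory: the outer call binds the version string, and the inner closure receives the test case. A self-contained approximation for use outside diffusers (using packaging.version in place of the library's is_torch_available/is_torch_version helpers; the test class is illustrative only):

    import unittest

    import torch
    import torch.nn.functional as F
    from packaging import version


    def require_torch_version_greater_equal(torch_version):
        """Skip the decorated test unless the installed torch is >= torch_version."""

        def decorator(test_case):
            ok = version.parse(torch.__version__) >= version.parse(torch_version)
            return unittest.skipUnless(ok, f"test requires torch >= {torch_version}")(test_case)

        return decorator


    class ExampleTests(unittest.TestCase):
        @require_torch_version_greater_equal("2.1")
        def test_bfloat16_interpolate(self):
            x = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)
            y = F.interpolate(x, scale_factor=2.0, mode="nearest")
            self.assertEqual(tuple(y.shape), (1, 3, 16, 16))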
@@ -27,6 +27,7 @@ from diffusers.models.transformers.transformer_2d import Transformer2DModel
 from diffusers.utils.testing_utils import (
     backend_manual_seed,
     require_torch_accelerator_with_fp64,
+    require_torch_version_greater_equal,
     torch_device,
 )
@@ -120,6 +121,21 @@ class Upsample2DBlockTests(unittest.TestCase):
         expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
+    @require_torch_version_greater_equal("2.1")
+    def test_upsample_bfloat16(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
+        upsample = Upsample2D(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)