mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Register BaseOutput subclasses as supported torch.utils._pytree nodes (#5459)
* Register BaseOutput subclasses as supported torch.utils._pytree nodes * lint --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -51,6 +51,21 @@ class BaseOutput(OrderedDict):
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
"""Register subclasses as pytree nodes.
|
||||
|
||||
This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
|
||||
`static_graph=True` with modules that output `ModelOutput` subclasses.
|
||||
"""
|
||||
if is_torch_available():
|
||||
import torch.utils._pytree
|
||||
|
||||
torch.utils._pytree._register_pytree_node(
|
||||
cls,
|
||||
torch.utils._pytree._dict_flatten,
|
||||
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
class_fields = fields(self)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
from diffusers.utils.testing_utils import require_torch
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -69,3 +70,24 @@ class ConfigTester(unittest.TestCase):
|
||||
assert dir(outputs_orig) == dir(outputs_copy)
|
||||
assert dict(outputs_orig) == dict(outputs_copy)
|
||||
assert vars(outputs_orig) == vars(outputs_copy)
|
||||
|
||||
@require_torch
|
||||
def test_torch_pytree(self):
|
||||
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
|
||||
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
|
||||
import torch
|
||||
import torch.utils._pytree
|
||||
|
||||
data = np.random.rand(1, 3, 4, 4)
|
||||
x = CustomOutput(images=data)
|
||||
self.assertFalse(torch.utils._pytree._is_leaf(x))
|
||||
|
||||
expected_flat_outs = [data]
|
||||
expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()])
|
||||
|
||||
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
|
||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||
self.assertEqual(expected_tree_spec, actual_tree_spec)
|
||||
|
||||
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
Reference in New Issue
Block a user