From c7617e482a522173ea6f922223aa010058552af8 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 24 Oct 2023 02:31:47 -0700 Subject: [PATCH] Register BaseOutput subclasses as supported torch.utils._pytree nodes (#5459) * Register BaseOutput subclasses as supported torch.utils._pytree nodes * lint --------- Co-authored-by: Dhruv Nair --- src/diffusers/utils/outputs.py | 15 +++++++++++++++ tests/others/test_outputs.py | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 802c699eb9..a057b506ae 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -51,6 +51,21 @@ class BaseOutput(OrderedDict): """ + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. + + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + if is_torch_available(): + import torch.utils._pytree + + torch.utils._pytree._register_pytree_node( + cls, + torch.utils._pytree._dict_flatten, + lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + ) + def __post_init__(self): class_fields = fields(self) diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 492e71f0ba..cf709d93f7 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -7,6 +7,7 @@ import numpy as np import PIL.Image from diffusers.utils.outputs import BaseOutput +from diffusers.utils.testing_utils import require_torch @dataclass @@ -69,3 +70,24 @@ class ConfigTester(unittest.TestCase): assert dir(outputs_orig) == dir(outputs_copy) assert dict(outputs_orig) == dict(outputs_copy) assert vars(outputs_orig) == vars(outputs_copy) + + @require_torch + def test_torch_pytree(self): + # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves) + # this is important for DistributedDataParallel gradient synchronization with static_graph=True + import torch + import torch.utils._pytree + + data = np.random.rand(1, 3, 4, 4) + x = CustomOutput(images=data) + self.assertFalse(torch.utils._pytree._is_leaf(x)) + + expected_flat_outs = [data] + expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()]) + + actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) + self.assertEqual(expected_flat_outs, actual_flat_outs) + self.assertEqual(expected_tree_spec, actual_tree_spec) + + unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) + self.assertEqual(x, unflattened_x)