From 9cfd4ef0746076febb589788ec47df9e2db43d65 Mon Sep 17 00:00:00 2001 From: Charles Bensimon Date: Fri, 29 Sep 2023 16:35:16 +0200 Subject: [PATCH] Make `BaseOutput` dataclasses picklable (#5234) * Make BaseOutput dataclasses picklable * make style * Test * Empty commit * Simpler and safer --- src/diffusers/utils/outputs.py | 9 ++++++++- tests/others/test_outputs.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 37b11561d1..802c699eb9 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -16,7 +16,7 @@ Generic utilities """ from collections import OrderedDict -from dataclasses import fields +from dataclasses import fields, is_dataclass from typing import Any, Tuple import numpy as np @@ -101,6 +101,13 @@ class BaseOutput(OrderedDict): # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 50cbd1d54e..492e71f0ba 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -1,3 +1,4 @@ +import pickle as pkl import unittest from dataclasses import dataclass from typing import List, Union @@ -58,3 +59,13 @@ class ConfigTester(unittest.TestCase): assert isinstance(outputs["images"][0], PIL.Image.Image) assert isinstance(outputs[0], list) assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_serialization(self): + outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + serialized = pkl.dumps(outputs_orig) + outputs_copy = pkl.loads(serialized) + + # Check original and copy are equal + assert dir(outputs_orig) == dir(outputs_copy) + assert dict(outputs_orig) == dict(outputs_copy) + assert vars(outputs_orig) == vars(outputs_copy)