mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make BaseOutput dataclasses picklable (#5234)
* Make BaseOutput dataclasses picklable * make style * Test * Empty commit * Simpler and safer
This commit is contained in:
@@ -16,7 +16,7 @@ Generic utilities
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import fields
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -101,6 +101,13 @@ class BaseOutput(OrderedDict):
|
||||
# Don't call self.__setattr__ to avoid recursion errors
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __reduce__(self):
|
||||
if not is_dataclass(self):
|
||||
return super().__reduce__()
|
||||
callable, _args, *remaining = super().__reduce__()
|
||||
args = tuple(getattr(self, field.name) for field in fields(self))
|
||||
return callable, args, *remaining
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
"""
|
||||
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pickle as pkl
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
@@ -58,3 +59,13 @@ class ConfigTester(unittest.TestCase):
|
||||
assert isinstance(outputs["images"][0], PIL.Image.Image)
|
||||
assert isinstance(outputs[0], list)
|
||||
assert isinstance(outputs[0][0], PIL.Image.Image)
|
||||
|
||||
def test_outputs_serialization(self):
|
||||
outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
|
||||
serialized = pkl.dumps(outputs_orig)
|
||||
outputs_copy = pkl.loads(serialized)
|
||||
|
||||
# Check original and copy are equal
|
||||
assert dir(outputs_orig) == dir(outputs_copy)
|
||||
assert dict(outputs_orig) == dict(outputs_copy)
|
||||
assert vars(outputs_orig) == vars(outputs_copy)
|
||||
|
||||
Reference in New Issue
Block a user