mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix BaseOutput initialization from dict (#570)
* Fix BaseOutput initialization from dict * style * Simplify post-init, add tests * remove debug
This commit is contained in:
@@ -59,10 +59,17 @@ class BaseOutput(OrderedDict):
|
||||
if not len(class_fields):
|
||||
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
||||
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
||||
|
||||
if other_fields_are_none and isinstance(first_field, dict):
|
||||
for key, value in first_field.items():
|
||||
self[key] = value
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
60
tests/test_outputs.py
Normal file
60
tests/test_outputs.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import PIL.Image
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomOutput(BaseOutput):
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
class ConfigTester(unittest.TestCase):
|
||||
def test_outputs_single_attribute(self):
|
||||
outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
|
||||
|
||||
# check every way of getting the attribute
|
||||
assert isinstance(outputs.images, np.ndarray)
|
||||
assert outputs.images.shape == (1, 3, 4, 4)
|
||||
assert isinstance(outputs["images"], np.ndarray)
|
||||
assert outputs["images"].shape == (1, 3, 4, 4)
|
||||
assert isinstance(outputs[0], np.ndarray)
|
||||
assert outputs[0].shape == (1, 3, 4, 4)
|
||||
|
||||
# test with a non-tensor attribute
|
||||
outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
|
||||
|
||||
# check every way of getting the attribute
|
||||
assert isinstance(outputs.images, list)
|
||||
assert isinstance(outputs.images[0], PIL.Image.Image)
|
||||
assert isinstance(outputs["images"], list)
|
||||
assert isinstance(outputs["images"][0], PIL.Image.Image)
|
||||
assert isinstance(outputs[0], list)
|
||||
assert isinstance(outputs[0][0], PIL.Image.Image)
|
||||
|
||||
def test_outputs_dict_init(self):
|
||||
# test output reinitialization with a `dict` for compatibility with `accelerate`
|
||||
outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
|
||||
|
||||
# check every way of getting the attribute
|
||||
assert isinstance(outputs.images, np.ndarray)
|
||||
assert outputs.images.shape == (1, 3, 4, 4)
|
||||
assert isinstance(outputs["images"], np.ndarray)
|
||||
assert outputs["images"].shape == (1, 3, 4, 4)
|
||||
assert isinstance(outputs[0], np.ndarray)
|
||||
assert outputs[0].shape == (1, 3, 4, 4)
|
||||
|
||||
# test with a non-tensor attribute
|
||||
outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
|
||||
|
||||
# check every way of getting the attribute
|
||||
assert isinstance(outputs.images, list)
|
||||
assert isinstance(outputs.images[0], PIL.Image.Image)
|
||||
assert isinstance(outputs["images"], list)
|
||||
assert isinstance(outputs["images"][0], PIL.Image.Image)
|
||||
assert isinstance(outputs[0], list)
|
||||
assert isinstance(outputs[0][0], PIL.Image.Image)
|
||||
Reference in New Issue
Block a user