mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* [Tests] parallelize * finish folder structuring * Parallelize tests more * Correct saving of pipelines * make sure logging level is correct * try again * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
67 lines
2.7 KiB
Python
67 lines
2.7 KiB
Python
import inspect
|
|
|
|
from diffusers.utils import is_flax_available
|
|
from diffusers.utils.testing_utils import require_flax
|
|
|
|
|
|
# JAX is an optional dependency: import it only when the Flax stack is
# installed so this module can still be imported in torch-only environments.
if is_flax_available():
    import jax
|
@require_flax
class FlaxModelTesterMixin:
    """Common contract tests for Flax models.

    Host test classes must provide:
      * ``model_class``: the Flax model class under test.
      * ``prepare_init_args_and_inputs_for_common()``: returns a
        ``(init_dict, inputs_dict)`` pair, where ``inputs_dict`` contains a
        ``"prng_key"`` and a ``"sample"`` array.
    """

    def _forward(self, init_dict, inputs_dict):
        """Build the model, initialize its variables, and run one forward pass.

        Returns the model output, unwrapped to its ``sample`` attribute when
        the model returns a dict-like output object.
        """
        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        # Keep freshly-initialized variables out of any gradient trace.
        jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        # Some models return a dict-subclass output whose tensor lives under
        # the ``sample`` attribute; unwrap it for shape checks.
        if isinstance(output, dict):
            output = output.sample
        return output

    def _assert_output_matches_input_shape(self, output, inputs_dict):
        """Assert the forward output exists and has the input sample's shape."""
        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_output(self):
        """With the default config, output shape must equal input sample shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        output = self._forward(init_dict, inputs_dict)
        self._assert_output_matches_input_shape(output, inputs_dict)

    def test_forward_with_norm_groups(self):
        """Same shape check with a non-default group-norm configuration."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Exercise a smaller norm-group / channel configuration than default.
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        output = self._forward(init_dict, inputs_dict)
        self._assert_output_matches_input_shape(output, inputs_dict)

    def test_deprecated_kwargs(self):
        """``**kwargs`` in ``__init__`` and ``_deprecated_kwargs`` must agree.

        A model either declares deprecated kwargs (and therefore accepts
        ``**kwargs`` in ``__init__``) or does neither; any mismatch between
        the two is an error.
        """
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )