mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SVD] Return np.ndarray when output_type="np" (#6507)
[SVD] Fix output_type="np"
This commit is contained in:
@@ -52,6 +52,9 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
return np.stack(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -185,6 +185,23 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
def test_np_output_type(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["output_type"] = "np"
|
||||
output = pipe(**inputs).frames
|
||||
self.assertTrue(isinstance(output, np.ndarray))
|
||||
self.assertEqual(len(output.shape), 5)
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
||||
Reference in New Issue
Block a user