mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
config fixes (#3060)
This commit is contained in:
@@ -115,8 +115,11 @@ class PipelineFastTests(unittest.TestCase):
|
||||
output = pipe(generator=generator, steps=4, return_dict=False)
|
||||
image_from_tuple = output[0][0]
|
||||
|
||||
assert audio.shape == (1, (self.dummy_unet.sample_size[1] - 1) * mel.hop_length)
|
||||
assert image.height == self.dummy_unet.sample_size[0] and image.width == self.dummy_unet.sample_size[1]
|
||||
assert audio.shape == (1, (self.dummy_unet.config.sample_size[1] - 1) * mel.hop_length)
|
||||
assert (
|
||||
image.height == self.dummy_unet.config.sample_size[0]
|
||||
and image.width == self.dummy_unet.config.sample_size[1]
|
||||
)
|
||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||
image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10]
|
||||
expected_slice = np.array([69, 255, 255, 255, 0, 0, 77, 181, 12, 127])
|
||||
@@ -133,14 +136,14 @@ class PipelineFastTests(unittest.TestCase):
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
np.random.seed(0)
|
||||
raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].sample_size[1] - 1) * mel.hop_length,))
|
||||
raw_audio = np.random.uniform(-1, 1, ((dummy_vqvae_and_unet[0].config.sample_size[1] - 1) * mel.hop_length,))
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
output = pipe(raw_audio=raw_audio, generator=generator, start_step=5, steps=10)
|
||||
image = output.images[0]
|
||||
|
||||
assert (
|
||||
image.height == self.dummy_vqvae_and_unet[0].sample_size[0]
|
||||
and image.width == self.dummy_vqvae_and_unet[0].sample_size[1]
|
||||
image.height == self.dummy_vqvae_and_unet[0].config.sample_size[0]
|
||||
and image.width == self.dummy_vqvae_and_unet[0].config.sample_size[1]
|
||||
)
|
||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
|
||||
@@ -183,8 +186,8 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||
audio = output.audios[0]
|
||||
image = output.images[0]
|
||||
|
||||
assert audio.shape == (1, (pipe.unet.sample_size[1] - 1) * pipe.mel.hop_length)
|
||||
assert image.height == pipe.unet.sample_size[0] and image.width == pipe.unet.sample_size[1]
|
||||
assert audio.shape == (1, (pipe.unet.config.sample_size[1] - 1) * pipe.mel.hop_length)
|
||||
assert image.height == pipe.unet.config.sample_size[0] and image.width == pipe.unet.config.sample_size[1]
|
||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||
expected_slice = np.array([151, 167, 154, 144, 122, 134, 121, 105, 70, 26])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user