mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
committed by
GitHub
parent
88fa6b7d68
commit
365ff8f76d
@@ -149,7 +149,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
timestep = timestep[None]
|
||||
|
||||
timestep_embed = self.time_proj(timestep)[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
|
||||
@@ -91,10 +91,14 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
)
|
||||
sample_size = int(sample_size)
|
||||
|
||||
audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)
|
||||
dtype = next(iter(self.unet.parameters())).dtype
|
||||
audio = torch.randn(
|
||||
(batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
|
||||
)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
|
||||
self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)
|
||||
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
@@ -103,7 +107,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
# 2. compute previous image: x_t -> t_t-1
|
||||
audio = self.scheduler.step(model_output, t, audio).prev_sample
|
||||
|
||||
audio = audio.clamp(-1, 1).cpu().numpy()
|
||||
audio = audio.clamp(-1, 1).float().cpu().numpy()
|
||||
|
||||
audio = audio[:, :, :original_sample_size]
|
||||
|
||||
|
||||
@@ -99,3 +99,20 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
||||
expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
|
||||
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_dance_diffusion_fp16(self):
|
||||
device = torch_device
|
||||
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)
|
||||
audio = output.audios
|
||||
|
||||
audio_slice = audio[0, -3:, -3:]
|
||||
|
||||
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
||||
expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937])
|
||||
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user