
[Dance Diffusion] FP16 (#980)

* add in fp16

* up
Author: Patrick von Platen
Date: 2022-10-25 19:33:43 +02:00
Committed by: GitHub
Parent: 88fa6b7d68
Commit: 365ff8f76d
3 changed files with 24 additions and 3 deletions
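
Usage, for context: with this change the pipeline can be loaded and run entirely in fp16. A minimal sketch (model id, step count, and sample length are taken from the new integration test below; the device is an assumption):

import torch
from diffusers import DanceDiffusionPipeline

# Load fp16 weights and generate ~4 s of audio; the output comes back as a NumPy array.
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # assumption: an fp16-capable accelerator
output = pipe(num_inference_steps=100, sample_length_in_s=4.096)
audio = output.audios   # shape (batch, channels, sample_size), dtype float32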


@@ -149,7 +149,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             timestep = timestep[None]

         timestep_embed = self.time_proj(timestep)[..., None]
-        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
+        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)

         # 2. down
         down_block_res_samples = ()
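
Why the cast is needed: self.time_proj builds the embedding in float32 even when sample is fp16, and combining mismatched dtypes later in the forward pass either errors or silently upcasts. A minimal standalone sketch (shapes are illustrative, not the model's; the channel-wise concat is an assumption about how the embedding is merged):

import torch

sample = torch.randn(1, 2, 8, dtype=torch.float16)   # fp16 activations
timestep_embed = torch.randn(1, 32, 8)               # float32 by default
timestep_embed = timestep_embed.to(sample.dtype)     # the fix: follow the sample's dtype
hidden = torch.cat([sample, timestep_embed], dim=1)  # would not stay fp16 without the cast
assert hidden.dtype == torch.float16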


@@ -91,10 +91,14 @@ class DanceDiffusionPipeline(DiffusionPipeline):
             )
         sample_size = int(sample_size)

-        audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)
+        dtype = next(iter(self.unet.parameters())).dtype
+        audio = torch.randn(
+            (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
+        )

         # set step values
         self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+        self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)

         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
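
Reading the dtype off the UNet's parameters keeps the pipeline agnostic to how it was loaded: the initial noise and the scheduler timesteps simply follow the weights, whether float32 or float16. A standalone sketch of the idiom (a Conv1d stands in for the real UNet1DModel):

import torch
from torch import nn

unet = nn.Conv1d(2, 2, kernel_size=3).half()     # stand-in for the real UNet1DModel
dtype = next(iter(unet.parameters())).dtype      # torch.float16 here
noise = torch.randn(1, 2, 16, dtype=dtype)       # initial noise matches the weights
assert noise.dtype == torch.float16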
@@ -103,7 +107,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
             # 2. compute previous image: x_t -> t_t-1
             audio = self.scheduler.step(model_output, t, audio).prev_sample

-        audio = audio.clamp(-1, 1).cpu().numpy()
+        audio = audio.clamp(-1, 1).float().cpu().numpy()
         audio = audio[:, :, :original_sample_size]
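
The added .float() upcasts before the NumPy conversion: a float16 tensor converts fine, but it yields an np.float16 array, which downstream audio tooling commonly rejects. A quick check of the behavior:

import torch

audio = torch.randn(1, 2, 16, dtype=torch.float16)
out = audio.clamp(-1, 1).float().cpu().numpy()
assert out.dtype.name == "float32"               # np.float16 without the .float()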


@@ -99,3 +99,20 @@ class PipelineIntegrationTests(unittest.TestCase):
         assert audio.shape == (1, 2, pipe.unet.sample_size)
         expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
         assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_dance_diffusion_fp16(self):
+        device = torch_device
+
+        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)
+
+        audio = output.audios
+        audio_slice = audio[0, -3:, -3:]
+
+        assert audio.shape == (1, 2, pipe.unet.sample_size)
+        expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937])
+        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2