1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Scheduling fixes on MPS (#10549)

* use np.int32 in scheduling

* test_add_noise_device

* -np.int32, fixes
This commit is contained in:
hlky
2025-01-16 17:45:03 +00:00
committed by GitHub
parent 9e1b8a0017
commit 08e62fe0c2
4 changed files with 5 additions and 5 deletions

View File

@@ -342,7 +342,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device, dtype=torch.float32)
# empty dt and derivative
self.prev_derivative = None

View File

@@ -311,7 +311,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

View File

@@ -99,7 +99,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device)
noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)

View File

@@ -361,7 +361,7 @@ class SchedulerCommonTest(unittest.TestCase):
if isinstance(t, torch.Tensor):
num_dims = len(sample.shape)
# pad t with 1s to match num_dims
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device, dtype=sample.dtype)
return sample * t / (t + 1)
@@ -722,7 +722,7 @@ class SchedulerCommonTest(unittest.TestCase):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device)
noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)