mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Scheduling fixes on MPS (#10549)
* use np.int32 in scheduling * test_add_noise_device * -np.int32, fixes
This commit is contained in:
@@ -342,7 +342,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps)
|
||||
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.timesteps = timesteps.to(device=device, dtype=torch.float32)
|
||||
|
||||
# empty dt and derivative
|
||||
self.prev_derivative = None
|
||||
|
||||
@@ -311,7 +311,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@@ -99,7 +99,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
|
||||
scaled_sample = scheduler.scale_model_input(sample, 0.0)
|
||||
self.assertEqual(sample.shape, scaled_sample.shape)
|
||||
|
||||
noise = torch.randn_like(scaled_sample).to(torch_device)
|
||||
noise = torch.randn(scaled_sample.shape).to(torch_device)
|
||||
t = scheduler.timesteps[5][None]
|
||||
noised = scheduler.add_noise(scaled_sample, noise, t)
|
||||
self.assertEqual(noised.shape, scaled_sample.shape)
|
||||
|
||||
@@ -361,7 +361,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
if isinstance(t, torch.Tensor):
|
||||
num_dims = len(sample.shape)
|
||||
# pad t with 1s to match num_dims
|
||||
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
|
||||
t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device, dtype=sample.dtype)
|
||||
|
||||
return sample * t / (t + 1)
|
||||
|
||||
@@ -722,7 +722,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scaled_sample = scheduler.scale_model_input(sample, 0.0)
|
||||
self.assertEqual(sample.shape, scaled_sample.shape)
|
||||
|
||||
noise = torch.randn_like(scaled_sample).to(torch_device)
|
||||
noise = torch.randn(scaled_sample.shape).to(torch_device)
|
||||
t = scheduler.timesteps[5][None]
|
||||
noised = scheduler.add_noise(scaled_sample, noise, t)
|
||||
self.assertEqual(noised.shape, scaled_sample.shape)
|
||||
|
||||
Reference in New Issue
Block a user