mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix unipc use_karras_sigmas exception - fixes huggingface/diffusers#4580 (#4581)
* Fix unipc karras sigmas exception - fixes huggingface/diffusers#4580 * Add unipc scheduler tests for karras sigmas
This commit is contained in:
@@ -278,6 +278,44 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
@@ -208,12 +208,24 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert abs(result_mean.item() - 0.2464) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras(self):
|
||||
sample = self.full_loop(use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2925) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1014) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras_and_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1966) < 1e-3
|
||||
|
||||
def test_fp16_support(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
|
||||
|
||||
Reference in New Issue
Block a user