1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

dynamic threshold sampling bug fixes and docs (#3003)

dynamic threshold sampling bug fix and docs
This commit is contained in:
Will Berman
2023-04-09 21:43:40 -07:00
committed by GitHub
parent dcfa6e1d20
commit 2cbdc586de
7 changed files with 206 additions and 83 deletions

View File

@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
# 4. Clip or threshold "predicted x_0"
if self.config.clip_sample:
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self._get_variance(timestep, prev_timestep)

View File

@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def step(
self,
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)
# 3. Clip or threshold "predicted x_0"
if self.config.clip_sample:
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t

View File

@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred)
if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]

View File

@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":

View File

@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":

View File

@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "epsilon":

View File

@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.6405) < 1e-3
assert abs(result_mean.item() - 1.1364) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")