mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
dynamic threshold sampling bug fixes and docs (#3003)
dynamic threshold sampling bug fix and docs
This commit is contained in:
@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
# 4. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
elif self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self._get_variance(timestep, prev_timestep)
|
||||
|
||||
@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return variance
|
||||
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
# 3. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
elif self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
||||
|
||||
@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
|
||||
@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
||||
@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
||||
@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
|
||||
@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.6405) < 1e-3
|
||||
assert abs(result_mean.item() - 1.1364) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
|
||||
Reference in New Issue
Block a user