diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 29a79d391e..dbce17868d 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         # 4. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )
 
-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 206294066c..e047a553a2 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def step(
         self,
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         # 3. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )
 
-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 39f8f17df5..acda0271ec 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             )
 
         if self.config.thresholding:
-            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-            orig_dtype = x0_pred.dtype
-            if orig_dtype not in [torch.float, torch.double]:
-                x0_pred = x0_pred.float()
-            x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+            x0_pred = self._threshold_sample(x0_pred)
 
         if self.config.algorithm_type == "deis":
             alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 474d9b0d73..320047f00a 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )
 
             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
+
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index a02171a2df..6e014f62a1 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 )
 
             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
+
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index e4f38d0f5d..7bee907929 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        original_shape = sample.shape  # remember the full shape; the input need not be a 4-D image batch
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten everything but the batch dim so the quantile is computed per sample, whatever its rank
+        sample = sample.flatten(1)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(original_shape)
+        sample = sample.to(dtype)
+
+        return sample
 
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )
 
             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
+
             return x0_pred
         else:
             if self.config.prediction_type == "epsilon":
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 295bbe8827..9da43714f5 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
         result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_mean.item() - 0.6405) < 1e-3
+        assert abs(result_mean.item() - 1.1364) < 1e-3
 
     def test_full_loop_with_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction")