Mirror of https://github.com/huggingface/diffusers.git
make fix-copies
@@ -87,6 +87,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        use_flow_sigmas: bool = False,
     ):
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
@@ -152,23 +153,19 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         if not isinstance(sigma, torch.Tensor):
             sigma = torch.tensor([sigma])
 
-        return sigma.atan() / math.pi * 2
+        if self.config.use_flow_sigmas:
+            c_noise = sigma / (sigma + 1)
+        else:
+            c_noise = sigma.atan() / math.pi * 2
+
+        return c_noise
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
     def precondition_outputs(self, sample, model_output, sigma):
-        sigma_data = self.config.sigma_data
-        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
-
-        if self.config.prediction_type == "epsilon":
-            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
-        elif self.config.prediction_type == "v_prediction":
-            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        if self.config.use_flow_sigmas:
+            return self._precondition_outputs_flow(sample, model_output, sigma)
         else:
-            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
-
-        denoised = c_skip * sample + c_out * model_output
-
-        return denoised
+            return self._precondition_outputs_edm(sample, model_output, sigma)
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
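As a side note on the hunk above: the two timestep conditionings differ only in how sigma is squashed into a unit range. A quick standalone comparison, illustrative only and not part of the commit:

import math
import torch

# Compare the original atan-based mapping with the new flow mapping t = sigma / (sigma + 1).
sigmas = torch.tensor([0.1, 1.0, 10.0])
c_noise_atan = sigmas.atan() / math.pi * 2   # original branch, maps (0, inf) -> (0, 1)
c_noise_flow = sigmas / (sigmas + 1)         # use_flow_sigmas branch, also maps (0, inf) -> (0, 1)
print(c_noise_atan)  # roughly [0.064, 0.500, 0.937]
print(c_noise_flow)  # roughly [0.091, 0.500, 0.909]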
@@ -570,8 +567,42 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
     def _get_conditioning_c_in(self, sigma):
-        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+        if self.config.use_flow_sigmas:
+            t = sigma / (sigma + 1)
+            c_in = 1.0 - t
+        else:
+            c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
 
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_flow
+    def _precondition_outputs_flow(self, sample, model_output, sigma):
+        t = sigma / (sigma + 1)
+        c_skip = 1.0 - t
+
+        if self.config.prediction_type == "epsilon":
+            c_out = -t
+        elif self.config.prediction_type == "v_prediction":
+            c_out = t
+        else:
+            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+        denoised = c_skip * sample + c_out * model_output
+        return denoised
+
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_edm
+    def _precondition_outputs_edm(self, sample, model_output, sigma):
+        sigma_data = self.config.sigma_data
+        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+
+        if self.config.prediction_type == "epsilon":
+            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        elif self.config.prediction_type == "v_prediction":
+            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        else:
+            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+        denoised = c_skip * sample + c_out * model_output
+        return denoised
+
     def __len__(self):
         return self.config.num_train_timesteps
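The two helpers added above implement flow-matching-style and EDM-style output preconditioning. A minimal standalone sketch of the same math follows, with sigma_data = 0.5 assumed as the usual default; the helper names here are illustrative stand-ins, not the diffusers API:

import torch

def precondition_flow(sample, model_output, sigma, prediction_type="epsilon"):
    # Flow branch: t = sigma / (sigma + 1), c_skip = 1 - t, c_out = -t (epsilon) or t (v_prediction).
    t = sigma / (sigma + 1)
    c_out = -t if prediction_type == "epsilon" else t
    return (1.0 - t) * sample + c_out * model_output

def precondition_edm(sample, model_output, sigma, sigma_data=0.5, prediction_type="epsilon"):
    # EDM branch: Karras-style skip/output scalings built from sigma and sigma_data.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    if prediction_type == "v_prediction":
        c_out = -c_out
    return c_skip * sample + c_out * model_output

sample, model_output = torch.randn(2, 4), torch.randn(2, 4)
sigma = torch.tensor(1.0)
print(precondition_flow(sample, model_output, sigma).shape)  # torch.Size([2, 4])
print(precondition_edm(sample, model_output, sigma).shape)   # torch.Size([2, 4])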
@@ -107,6 +107,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        use_flow_sigmas: bool = False,
     ):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
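Both __init__ signatures touched so far (presumably in scheduling_cosine_dpmsolver_multistep.py and scheduling_edm_dpmsolver_multistep.py) gain the same use_flow_sigmas flag, which is registered to the scheduler config. A quick sanity check one could run against a diffusers build that includes this change; this snippet is hypothetical and not part of the commit:

from diffusers import CosineDPMSolverMultistepScheduler, EDMDPMSolverMultistepScheduler

cosine = CosineDPMSolverMultistepScheduler(use_flow_sigmas=True)
edm = EDMDPMSolverMultistepScheduler(use_flow_sigmas=True)
print(cosine.config.use_flow_sigmas, edm.config.use_flow_sigmas)  # True True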
@@ -185,25 +186,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         if not isinstance(sigma, torch.Tensor):
             sigma = torch.tensor([sigma])
 
-        c_noise = 0.25 * torch.log(sigma)
+        if self.config.use_flow_sigmas:
+            c_noise = sigma / (sigma + 1)
+        else:
+            c_noise = 0.25 * torch.log(sigma)
 
         return c_noise
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
     def precondition_outputs(self, sample, model_output, sigma):
-        sigma_data = self.config.sigma_data
-        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
-
-        if self.config.prediction_type == "epsilon":
-            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
-        elif self.config.prediction_type == "v_prediction":
-            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        if self.config.use_flow_sigmas:
+            return self._precondition_outputs_flow(sample, model_output, sigma)
         else:
-            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
-
-        denoised = c_skip * sample + c_out * model_output
-
-        return denoised
+            return self._precondition_outputs_edm(sample, model_output, sigma)
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
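Same refactor as in the cosine scheduler, except that the EDM scheduler's default c_noise is 0.25 * log(sigma), which is unbounded, while the flow mapping stays inside [0, 1). A small numeric comparison, illustrative only:

import torch

sigmas = torch.tensor([0.1, 1.0, 10.0])
print(0.25 * torch.log(sigmas))  # roughly [-0.576, 0.000, 0.576]
print(sigmas / (sigmas + 1))     # roughly [0.091, 0.500, 0.909]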
@@ -705,8 +700,42 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
    def _get_conditioning_c_in(self, sigma):
-        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+        if self.config.use_flow_sigmas:
+            t = sigma / (sigma + 1)
+            c_in = 1.0 - t
+        else:
+            c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
 
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_flow
+    def _precondition_outputs_flow(self, sample, model_output, sigma):
+        t = sigma / (sigma + 1)
+        c_skip = 1.0 - t
+
+        if self.config.prediction_type == "epsilon":
+            c_out = -t
+        elif self.config.prediction_type == "v_prediction":
+            c_out = t
+        else:
+            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+        denoised = c_skip * sample + c_out * model_output
+        return denoised
+
+    # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_edm
+    def _precondition_outputs_edm(self, sample, model_output, sigma):
+        sigma_data = self.config.sigma_data
+        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+
+        if self.config.prediction_type == "epsilon":
+            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        elif self.config.prediction_type == "v_prediction":
+            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+        else:
+            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+        denoised = c_skip * sample + c_out * model_output
+        return denoised
+
     def __len__(self):
         return self.config.num_train_timesteps
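The _get_conditioning_c_in change above swaps the input scaling as well. With sigma_data = 0.5 assumed as the default, the two scalings diverge noticeably at larger sigma; illustrative arithmetic, not from the commit:

import torch

sigma, sigma_data = torch.tensor(2.0), 0.5
c_in_edm = 1 / ((sigma**2 + sigma_data**2) ** 0.5)  # EDM input scaling
c_in_flow = 1.0 - sigma / (sigma + 1)               # flow input scaling, equals 1 / (sigma + 1)
print(float(c_in_edm), float(c_in_flow))  # about 0.485 and 0.333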
@@ -422,6 +422,21 @@ class Cosmos2TextToImagePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class CosmosTextToWorldPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
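The final hunk is the auto-generated dummy object for the new pipeline: it lets `import diffusers` succeed without torch/transformers installed and only fails when the class is actually used. A simplified sketch of the pattern; this is not the diffusers implementation, and the always-raising requires_backends here just simulates missing backends:

class DummyObject(type):
    # Metaclass so that even classmethod access on the placeholder raises.
    def __getattr__(cls, key):
        raise ImportError(f"{cls.__name__} requires the backends {cls._backends}")

def requires_backends(obj, backends):
    # Simplified: always raise, as if the listed backends were not installed.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the backends {backends}")

class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])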