commit f163bccc4e (parent 864d7b846e)
Author: Nathan Lambert
Date:   2022-10-27 10:59:59 -07:00


@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -42,10 +41,10 @@ class ALDSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None
class ALDScheduler(SchedulerMixin, ConfigMixin):
"""
- The Annealed Langevin Dynamics sampler was popularized in the paper on Noise Conditional Score Networks (NCSNs). For more details, refer to the paper https://arxiv.org/abs/1907.05600
+ The Annealed Langevin Dynamics sampler was popularized in the paper on Noise Conditional Score Networks (NCSNs).
+ For more details, refer to the paper https://arxiv.org/abs/1907.05600
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
@@ -96,9 +95,7 @@ class ALDScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
- def set_timesteps(
-     self, num_inference_steps: int, device: Union[str, torch.device] = None
- ):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -114,9 +111,7 @@ class ALDScheduler(SchedulerMixin, ConfigMixin):
)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps).to(device)
- def set_sigmas(
-     self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None
- ):
+ def set_sigmas(self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None):
"""
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
@@ -137,8 +132,9 @@ class ALDScheduler(SchedulerMixin, ConfigMixin):
self.set_timesteps(num_inference_steps)
self.sigmas = torch.tensor(
-     torch.exp(torch.linspace(torch.log(sigma_min), torch.log(sigma_max),
-               num_inference_steps)), dtype=torch.float32)
+     torch.exp(torch.linspace(torch.log(sigma_min), torch.log(sigma_max), num_inference_steps)),
+     dtype=torch.float32,
+ )
self.final_noise_sigma = self.sigmas[-1]
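
For context on what the two setup helpers touched by this diff compute before inference: set_timesteps stores a reversed sequence of continuous timesteps (note the [::-1].copy() and torch.from_numpy above), and set_sigmas stores log-linearly spaced noise scales between sigma_min and sigma_max, keeping the last entry as final_noise_sigma. Below is a minimal standalone sketch of those two computations; the step count, the 0-to-1 timestep range, and the sigma endpoints are placeholder assumptions rather than the scheduler's config defaults, and math.log is used on the scalar endpoints instead of the torch.log call shown in the diff.

import math

import numpy as np
import torch

# Placeholder values; the real defaults come from the scheduler's config (ConfigMixin).
num_inference_steps = 10
sigma_min, sigma_max = 0.01, 50.0

# Reversed continuous timesteps. The 0-to-1 range is an assumption; only the
# [::-1].copy() reversal and the numpy -> torch conversion mirror the diff above.
# .copy() makes the reversed view contiguous so torch.from_numpy accepts it.
timesteps = np.linspace(0, 1, num_inference_steps)[::-1].copy()
timesteps = torch.from_numpy(timesteps)

# Log-linearly spaced noise scales from sigma_min to sigma_max, equivalent to the
# exp(linspace(log(min), log(max), N)) expression in the diff.
sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
final_noise_sigma = sigmas[-1]  # last entry, stored as final_noise_sigma above

print(timesteps.shape, sigmas[0].item(), final_noise_sigma.item())

Spacing the sigmas this way keeps the ratio between consecutive noise levels constant, matching the geometric annealing schedule used in the NCSN paper linked in the docstring.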