1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

rename schedulers

This commit is contained in:
Patrick von Platen
2022-06-13 10:39:53 +02:00
parent 5c21d96284
commit 27266abc9f
9 changed files with 21 additions and 18 deletions

View File

@@ -35,7 +35,7 @@ Both models and schedulers should be load- and saveable from the Hub.
```python
import torch
from diffusers import UNetModel, GaussianDDPMScheduler
from diffusers import UNetModel, DDPMScheduler
import PIL
import numpy as np
import tqdm
@@ -44,7 +44,7 @@ generator = torch.manual_seed(0)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Load models
noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", tensor_format="pt")
noise_scheduler = DDPMScheduler.from_config("fusing/ddpm-lsun-church", tensor_format="pt")
unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise

View File

@@ -10,8 +10,10 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
from .schedulers import SchedulerMixin
from .schedulers.scheduling_ddim import DDIMScheduler
from .schedulers.scheduling_ddpm import DDPMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.ddim import DDIMScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler

View File

@@ -1,12 +1,12 @@
#!/usr/bin/env python3
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers import DDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
diffusion = DDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)

View File

@@ -1,12 +1,12 @@
#!/usr/bin/env python3
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers import DDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
diffusion = DDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)

View File

@@ -16,8 +16,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .ddim import DDIMScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .schedulers_utils import SchedulerMixin
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .glide_ddim import GlideDDIMScheduler

View File

@@ -19,7 +19,7 @@ from ..configuration_utils import ConfigMixin
from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule
class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,

View File

@@ -19,7 +19,7 @@ import unittest
import torch
from diffusers import DDIM, DDPM, DDIMScheduler, GaussianDDPMScheduler, LatentDiffusion, UNetModel
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -107,7 +107,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
schedular = GaussianDDPMScheduler(timesteps=10)
schedular = DDPMScheduler(timesteps=10)
ddpm = DDPM(model, schedular)
@@ -147,7 +147,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
noise_scheduler = GaussianDDPMScheduler.from_config(model_id)
noise_scheduler = DDPMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler)

View File

@@ -20,7 +20,7 @@ import unittest
import numpy as np
import torch
from diffusers import DDIMScheduler, GaussianDDPMScheduler
from diffusers import DDIMScheduler, DDPMScheduler
torch.backends.cuda.matmul.allow_tf32 = False
@@ -163,7 +163,7 @@ class SchedulerCommonTest(unittest.TestCase):
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (GaussianDDPMScheduler,)
scheduler_classes = (DDPMScheduler,)
def get_scheduler_config(self, **kwargs):
config = {