
make tests pass

This commit is contained in:
Patrick von Platen
2022-06-12 17:59:39 +00:00
parent 08e7f4b063
commit e83ff11f57
8 changed files with 89 additions and 235 deletions

View File

@@ -13,3 +13,4 @@ from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.ddim import DDIMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
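With this re-export in place, the four pipelines become importable from the package root instead of from their defining submodules. A minimal sketch of what the added line enables:

# After this commit the pipeline classes are re-exported at the package
# root, so downstream code no longer needs to know the submodule layout.
from diffusers import DDIM, DDPM, GLIDE, LatentDiffusion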

View File

@@ -1 +1,4 @@
from pipeline_dd
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_glide import GLIDE

View File

@@ -17,7 +17,7 @@
import torch
import tqdm
from .. import DiffusionPipeline
from ..pipeline_utils import DiffusionPipeline
class DDPM(DiffusionPipeline):

View File

@@ -24,13 +24,10 @@ import torch.utils.checkpoint
from torch import nn
import tqdm
from .. import (
ClassifierFreeGuidanceScheduler,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from ..pipeline_utils import DiffusionPipeline
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
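The refactor in this hunk replaces imports from the package root (from .. import ...) with imports from the defining submodules (..pipeline_utils, ..models, ..schedulers). One common motivation for this pattern is breaking circular imports: the root __init__.py itself imports the pipeline modules, so a pipeline module that imports back through the root can execute before the names it needs are bound. A minimal, hypothetical reproduction (module names are illustrative only, not diffusers' actual layout):

# pkg/__init__.py
from .a import A        # (1) starts executing a.py before B is bound below
from .b import B

# pkg/b.py
class B:
    pass

# pkg/a.py
# from . import B       # (2) would fail: pkg/__init__.py is still running,
#                       #     so the name `B` does not exist on `pkg` yet
from .b import B        # importing the defining submodule directly works

class A(B):
    pass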

View File

@@ -6,7 +6,9 @@ import tqdm
import torch
import torch.nn as nn
from .. import DiffusionPipeline, ConfigMixin, ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
def get_timestep_embedding(timesteps, embedding_dim):

View File

@@ -0,0 +1,54 @@
import os
import random
import unittest
import torch
from distutils.util import strtobool
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
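A short usage sketch for the new helpers; the test class and test names below are illustrative, not part of the library:

import unittest

from diffusers.testing_utils import floats_tensor, slow, torch_device

class ExampleTest(unittest.TestCase):
    def test_noise_shape(self):
        # floats_tensor draws from the module-level RNG unless `rng` is passed
        noise = floats_tensor((2, 3, 32, 32), scale=2.0).to(torch_device)
        self.assertEqual(tuple(noise.shape), (2, 3, 32, 32))

    @slow  # skipped unless e.g. RUN_SLOW=yes is set in the environment
    def test_full_sampling(self):
        ...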

View File

@@ -14,71 +14,21 @@
# limitations under the License.
import os
import random
import tempfile
import unittest
from distutils.util import strtobool
import torch
from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler
from diffusers import DDIM, DDPM, LatentDiffusion
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM
from models.vision.latent_diffusion.modeling_latent_diffusion import LatentDiffusion
from diffusers.testing_utils import floats_tensor, torch_device, slow
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = False
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
@@ -124,7 +74,7 @@ class ModelTesterMixin(unittest.TestCase):
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10])
return (noise, time_step)
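The .to(torch_device) added here matters because tensors created by floats_tensor live on the CPU, while the models under test may have been moved to CUDA; a forward pass with mismatched devices raises a RuntimeError. Illustrative only:

noise = floats_tensor((4, 3, 32, 32))      # created on CPU
noise = noise.to(torch_device)             # move to wherever the model lives
assert noise.device.type == torch_device   # torch_device is "cuda" or "cpu"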
@@ -151,116 +101,6 @@ class ModelTesterMixin(unittest.TestCase):
assert image is not None, "Make sure output is not None"
class SamplerTesterMixin(unittest.TestCase):
@slow
def test_sample(self):
generator = torch.manual_seed(0)
# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
with torch.no_grad():
noise_residual = model(image, t)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
# Note: The better test is to simply check with the following lines of code that the image is sensible
# import PIL
# import numpy as np
# image_processed = image.cpu().permute(0, 2, 3, 1)
# image_processed = (image_processed + 1.0) * 127.5
# image_processed = image_processed.numpy().astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("test.png")
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor(
[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
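For reference, the loop above (and its twin in test_sample_fast below) implements the DDPM posterior mean from Ho et al. (2020), Eq. 7: first a clipped estimate of x_0 is recovered from the predicted noise, then the mean of q(x_{t-1} | x_t, x_0) is formed, with the sampled prev_variance term playing the role of the sigma_t * z noise:

\hat{x}_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}\;\epsilon_\theta(x_t, t), \qquad \hat{x}_0 \mapsto \operatorname{clip}(\hat{x}_0, -1, 1)

\tilde{\mu}_t = \underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}}_{\texttt{clipped\_coeff}} \hat{x}_0 \;+\; \underbrace{\frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}}_{\texttt{image\_coeff}} x_t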
def test_sample_fast(self):
# 1. Load models
generator = torch.manual_seed(0)
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (
(1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
with torch.no_grad():
noise_residual = model(image, t)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor([-0.0304, -0.1895, -0.2436, -0.9837, -0.5422, 0.1931, -0.8175, 0.0862, -0.7783])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models

View File

@@ -14,72 +14,17 @@
# limitations under the License.
import os
import random
import tempfile
import unittest
import numpy as np
from distutils.util import strtobool
import torch
import numpy as np
import unittest
import tempfile
from diffusers import GaussianDDPMScheduler, DDIMScheduler
from diffusers import GaussianDDPMScheduler, UNetModel, DDIMScheduler
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM
from models.vision.latent_diffusion.modeling_latent_diffusion import LatentDiffusion
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = False
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return np.random.randn(data=values, dtype=torch.float).view(shape).contiguous()
class SchedulerCommonTest(unittest.TestCase):
scheduler_class = None
@@ -106,7 +51,6 @@ class SchedulerCommonTest(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
image = self.dummy_image
residual = 0.1 * image
scheduler_config = self.get_scheduler_config()
@@ -120,3 +64,16 @@ class SchedulerCommonTest(unittest.TestCase):
new_output = new_scheduler(residual, image, 1)
def test_step(self):
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_class(**scheduler_config)
image = self.dummy_image
residual = 0.1 * image
output_0 = scheduler(residual, image, 0)
output_1 = scheduler(residual, image, 1)
self.assertEqual(output_0.shape, image.shape)
self.assertEqual(output_0.shape, output_1.shape)
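To show how the mixin is meant to be specialized, a hypothetical concrete subclass; the config keys below are placeholders, since each scheduler defines its own constructor arguments:

from diffusers import DDIMScheduler

class DDIMSchedulerTest(SchedulerCommonTest):
    scheduler_class = DDIMScheduler

    def get_scheduler_config(self):
        # placeholder values; the real keyword names depend on the scheduler
        return {"timesteps": 1000}

test_step then exercises the subclass's scheduler at steps 0 and 1 using the dummy_image provided by the mixin.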