Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
VQ-diffusion (#658)
* Changes for the VQ-diffusion VQVAE
  - Allow specifying the embedding dimension of `VQModel`. By default `VQModel` sets the embedding dimension to the number of latent channels, but the VQ-diffusion VQVAE uses a smaller embedding dimension (128) than its number of latent channels (256).
  - Add `AttnDownEncoderBlock2D` and `AttnUpDecoderBlock2D` to the down and up unet block helpers; VQ-diffusion's VQVAE uses both block types.

* Changes for the VQ-diffusion transformer
  Modify attention.py so `SpatialTransformer` can be used for VQ-diffusion's transformer:
  - `SpatialTransformer`: can now operate over discrete inputs (classes of vector embeddings) as well as continuous ones; `in_channels` was made optional in the constructor, so two call sites that passed it positionally now pass it as a kwarg; the forward pass takes optional timestep embeddings.
  - `ImagePositionalEmbeddings`: added to provide positional embeddings for the latent pixels of discrete inputs.
  - `BasicTransformerBlock`: norm layers are now configurable so that VQ-diffusion can use `AdaLayerNorm` with timestep embeddings; the forward pass takes optional timestep embeddings.
  - `CrossAttention`: may now optionally take a bias for its query, key, and value linear layers.
  - `FeedForward`: internal layers are now configurable.
  - `ApproximateGELU`: the activation function used in VQ-diffusion's feedforward layer.
  - `AdaLayerNorm`: norm layer modified to incorporate timestep embeddings.

* Add VQ-diffusion scheduler
* Add VQ-diffusion pipeline
* Add VQ-diffusion convert script to diffusers
* Add VQ-diffusion dummy objects
* Add VQ-diffusion markdown docs
* Add VQ-diffusion tests
* Renaming, typo and weight fixes, test fixes, and review suggestions applied

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
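To make the new APIs concrete, here are two minimal usage sketches assembled from the tests added in this commit; treat them as illustrations, not official documentation. The first mirrors the integration test below (the `microsoft/vq-diffusion-ithq` checkpoint, `truncation_rate=0.86`, and `output_type="np"` are all taken from that test):

```python
import torch

from diffusers import VQDiffusionPipeline

# Sketch mirroring VQDiffusionPipelineIntegrationTests.test_vq_diffusion below.
pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to("cuda")  # assumes a CUDA device; the test uses torch_device

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipeline(
    "teddy bear playing in the pool",
    truncation_rate=0.86,      # value used in the integration test
    num_images_per_prompt=1,
    generator=generator,
    output_type="np",          # the test expects a (256, 256, 3) numpy image
)
image = output.images[0]
```

The second sketch shows the discrete-input mode of the renamed `Transformer2DModel`, mirroring `test_spatial_transformer_discrete` further down (all argument values come from that test):

```python
import torch

from diffusers import Transformer2DModel

# Operate over classes of vector embeddings instead of continuous feature maps.
model = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=32,
    num_vector_embeds=5,  # number of discrete embedding classes
    sample_size=16,
).eval()

latents = torch.randint(0, 5, (1, 32))  # batch of discrete latent-pixel classes
with torch.no_grad():
    out = model(latents).sample  # shape (1, num_vector_embeds - 1, 32), as asserted in the test
```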
New files:
  tests/pipelines/vq_diffusion/__init__.py (0 lines)
  tests/pipelines/vq_diffusion/test_vq_diffusion.py (175 lines)
tests/pipelines/vq_diffusion/test_vq_diffusion.py
@@ -0,0 +1,175 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.utils import load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def num_embed(self):
        return 12

    @property
    def num_embeds_ada_norm(self):
        return 12

    @property
    def dummy_vqvae(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
            num_vq_embeddings=self.num_embed,
            vq_embed_dim=3,
        )
        return model

    @property
    def dummy_tokenizer(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_transformer(self):
        torch.manual_seed(0)

        height = 12
        width = 12

        model_kwargs = {
            "attention_bias": True,
            "cross_attention_dim": 32,
            "attention_head_dim": height * width,
            "num_attention_heads": 1,
            "num_vector_embeds": self.num_embed,
            "num_embeds_ada_norm": self.num_embeds_ada_norm,
            "norm_num_groups": 32,
            "sample_size": width,
            "activation_fn": "geglu-approximate",
        }

        model = Transformer2DModel(**model_kwargs)
        return model

    def test_vq_diffusion(self):
        device = "cpu"

        vqvae = self.dummy_vqvae
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        transformer = self.dummy_transformer
        scheduler = VQDiffusionScheduler(self.num_embed)

        pipe = VQDiffusionPipeline(
            vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
        )
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "teddy bear playing in the pool"

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = pipe(
            [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 24, 24, 3)

        expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_vq_diffusion(self):
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/vq_diffusion/teddy_bear_pool.png"
        )
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipeline(
            "teddy bear playing in the pool",
            truncation_rate=0.86,
            num_images_per_prompt=1,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).max() < 1e-2
@@ -18,8 +18,9 @@ import unittest

 import numpy as np
 import torch
+from torch import nn

-from diffusers.models.attention import AttentionBlock, SpatialTransformer
+from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
 from diffusers.utils import torch_device
@@ -235,7 +236,7 @@ class AttentionBlockTests(unittest.TestCase):
             num_head_channels=1,
             rescale_output_factor=1.0,
             eps=1e-6,
-            num_groups=32,
+            norm_num_groups=32,
         ).to(torch_device)
         with torch.no_grad():
             attention_scores = attentionBlock(sample)
@@ -259,7 +260,7 @@ class AttentionBlockTests(unittest.TestCase):
             channels=512,
             rescale_output_factor=1.0,
             eps=1e-6,
-            num_groups=32,
+            norm_num_groups=32,
         ).to(torch_device)
         with torch.no_grad():
             attention_scores = attentionBlock(sample)
@@ -273,22 +274,22 @@ class AttentionBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


-class SpatialTransformerTests(unittest.TestCase):
+class Transformer2DModelTests(unittest.TestCase):
     def test_spatial_transformer_default(self):
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
-        spatial_transformer_block = SpatialTransformer(
+        spatial_transformer_block = Transformer2DModel(
             in_channels=32,
-            n_heads=1,
-            d_head=32,
+            num_attention_heads=1,
+            attention_head_dim=32,
             dropout=0.0,
-            context_dim=None,
+            cross_attention_dim=None,
         ).to(torch_device)
         with torch.no_grad():
-            attention_scores = spatial_transformer_block(sample)
+            attention_scores = spatial_transformer_block(sample).sample

         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
@@ -298,22 +299,22 @@ class SpatialTransformerTests(unittest.TestCase):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

-    def test_spatial_transformer_context_dim(self):
+    def test_spatial_transformer_cross_attention_dim(self):
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

         sample = torch.randn(1, 64, 64, 64).to(torch_device)
-        spatial_transformer_block = SpatialTransformer(
+        spatial_transformer_block = Transformer2DModel(
             in_channels=64,
-            n_heads=2,
-            d_head=32,
+            num_attention_heads=2,
+            attention_head_dim=32,
             dropout=0.0,
-            context_dim=64,
+            cross_attention_dim=64,
         ).to(torch_device)
         with torch.no_grad():
             context = torch.randn(1, 4, 64).to(torch_device)
-            attention_scores = spatial_transformer_block(sample, context)
+            attention_scores = spatial_transformer_block(sample, context).sample

         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
@@ -323,6 +324,44 @@ class SpatialTransformerTests(unittest.TestCase):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

+    def test_spatial_transformer_timestep(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        num_embeds_ada_norm = 5
+
+        sample = torch.randn(1, 64, 64, 64).to(torch_device)
+        spatial_transformer_block = Transformer2DModel(
+            in_channels=64,
+            num_attention_heads=2,
+            attention_head_dim=32,
+            dropout=0.0,
+            cross_attention_dim=64,
+            num_embeds_ada_norm=num_embeds_ada_norm,
+        ).to(torch_device)
+        with torch.no_grad():
+            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
+            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
+            attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
+            attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
+
+        assert attention_scores_1.shape == (1, 64, 64, 64)
+        assert attention_scores_2.shape == (1, 64, 64, 64)
+
+        output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
+        output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
+
+        expected_slice_1 = torch.tensor(
+            [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
+        )
+        expected_slice_2 = torch.tensor(
+            [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
+        )
+
+        assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
+        assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)
+
     def test_spatial_transformer_dropout(self):
         torch.manual_seed(0)
         if torch.cuda.is_available():
@@ -330,18 +369,18 @@ class SpatialTransformerTests(unittest.TestCase):

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = (
-            SpatialTransformer(
+            Transformer2DModel(
                 in_channels=32,
-                n_heads=2,
-                d_head=16,
+                num_attention_heads=2,
+                attention_head_dim=16,
                 dropout=0.3,
-                context_dim=None,
+                cross_attention_dim=None,
             )
             .to(torch_device)
             .eval()
         )
         with torch.no_grad():
-            attention_scores = spatial_transformer_block(sample)
+            attention_scores = spatial_transformer_block(sample).sample

         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
@@ -350,3 +389,107 @@ class SpatialTransformerTests(unittest.TestCase):
             [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
+    def test_spatial_transformer_discrete(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        num_embed = 5
+
+        sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
+        spatial_transformer_block = (
+            Transformer2DModel(
+                num_attention_heads=1,
+                attention_head_dim=32,
+                num_vector_embeds=num_embed,
+                sample_size=16,
+            )
+            .to(torch_device)
+            .eval()
+        )
+
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample).sample
+
+        assert attention_scores.shape == (1, num_embed - 1, 32)
+
+        output_slice = attention_scores[0, -2:, -3:]
+
+        expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device)
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_default_norm_layers(self):
+        spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)
+
+        assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
+        assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm
+        assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
+
+    def test_spatial_transformer_ada_norm_layers(self):
+        spatial_transformer_block = Transformer2DModel(
+            num_attention_heads=1,
+            attention_head_dim=32,
+            in_channels=32,
+            num_embeds_ada_norm=5,
+        )
+
+        assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
+        assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm
+        assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
+
+    def test_spatial_transformer_default_ff_layers(self):
+        spatial_transformer_block = Transformer2DModel(
+            num_attention_heads=1,
+            attention_head_dim=32,
+            in_channels=32,
+        )
+
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+
+        dim = 32
+        inner_dim = 128
+
+        # First dimension change
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
+        # NOTE: inner_dim * 2 because GEGLU
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2
+
+        # Second dimension change
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
+
+    def test_spatial_transformer_geglu_approx_ff_layers(self):
+        spatial_transformer_block = Transformer2DModel(
+            num_attention_heads=1,
+            attention_head_dim=32,
+            in_channels=32,
+            activation_fn="geglu-approximate",
+        )
+
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+
+        dim = 32
+        inner_dim = 128
+
+        # First dimension change
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim
+
+        # Second dimension change
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
+
+    def test_spatial_transformer_attention_bias(self):
+        spatial_transformer_block = Transformer2DModel(
+            num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
+        )
+
+        assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
+        assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
+        assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None
@@ -19,6 +19,7 @@ from typing import Dict, List, Tuple

 import numpy as np
 import torch
+import torch.nn.functional as F

 from diffusers import (
     DDIMScheduler,
@@ -29,6 +30,7 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
     ScoreSdeVeScheduler,
+    VQDiffusionScheduler,
 )
 from diffusers.utils import torch_device
@@ -85,12 +87,18 @@ class SchedulerCommonTest(unittest.TestCase):
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 time_step = float(time_step)

-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)

+            if scheduler_class == VQDiffusionScheduler:
+                num_vec_classes = scheduler_config["num_vec_classes"]
+                sample = self.dummy_sample(num_vec_classes)
+                model = self.dummy_model(num_vec_classes)
+                residual = model(sample, time_step)
+            else:
+                sample = self.dummy_sample
+                residual = 0.1 * sample
+
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -122,12 +130,18 @@ class SchedulerCommonTest(unittest.TestCase):
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 time_step = float(time_step)

-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

+            if scheduler_class == VQDiffusionScheduler:
+                num_vec_classes = scheduler_config["num_vec_classes"]
+                sample = self.dummy_sample(num_vec_classes)
+                model = self.dummy_model(num_vec_classes)
+                residual = model(sample, time_step)
+            else:
+                sample = self.dummy_sample
+                residual = 0.1 * sample
+
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -154,15 +168,21 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)

         for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-            timestep = 1
-            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
-                timestep = float(timestep)
-
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

+            timestep = 1
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                timestep = float(timestep)
+            if scheduler_class == VQDiffusionScheduler:
+                num_vec_classes = scheduler_config["num_vec_classes"]
+                sample = self.dummy_sample(num_vec_classes)
+                model = self.dummy_model(num_vec_classes)
+                residual = model(sample, timestep)
+            else:
+                sample = self.dummy_sample
+                residual = 0.1 * sample
+
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
@@ -200,8 +220,14 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

-            sample = self.dummy_sample
-            residual = 0.1 * sample
+            if scheduler_class == VQDiffusionScheduler:
+                num_vec_classes = scheduler_config["num_vec_classes"]
+                sample = self.dummy_sample(num_vec_classes)
+                model = self.dummy_model(num_vec_classes)
+                residual = model(sample, timestep_0)
+            else:
+                sample = self.dummy_sample
+                residual = 0.1 * sample

             if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                 scheduler.set_timesteps(num_inference_steps)
@@ -255,8 +281,14 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

-            sample = self.dummy_sample
-            residual = 0.1 * sample
+            if scheduler_class == VQDiffusionScheduler:
+                num_vec_classes = scheduler_config["num_vec_classes"]
+                sample = self.dummy_sample(num_vec_classes)
+                model = self.dummy_model(num_vec_classes)
+                residual = model(sample, timestep)
+            else:
+                sample = self.dummy_sample
+                residual = 0.1 * sample

             if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                 scheduler.set_timesteps(num_inference_steps)
@@ -284,22 +316,26 @@ class SchedulerCommonTest(unittest.TestCase):
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
-            self.assertTrue(
-                hasattr(scheduler, "init_noise_sigma"),
-                f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
-            )
-            self.assertTrue(
-                hasattr(scheduler, "scale_model_input"),
-                f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
-            )

+            if scheduler_class != VQDiffusionScheduler:
+                self.assertTrue(
+                    hasattr(scheduler, "init_noise_sigma"),
+                    f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
+                )
+                self.assertTrue(
+                    hasattr(scheduler, "scale_model_input"),
+                    f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
+                    " timestep)`",
+                )
             self.assertTrue(
                 hasattr(scheduler, "step"),
                 f"{scheduler_class} does not implement a required class method `step(...)`",
             )

-            sample = self.dummy_sample
-            scaled_sample = scheduler.scale_model_input(sample, 0.0)
-            self.assertEqual(sample.shape, scaled_sample.shape)
+            if scheduler_class != VQDiffusionScheduler:
+                sample = self.dummy_sample
+                scaled_sample = scheduler.scale_model_input(sample, 0.0)
+                self.assertEqual(sample.shape, scaled_sample.shape)

     def test_add_noise_device(self):
         for scheduler_class in self.scheduler_classes:
@@ -1238,3 +1274,53 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))

         assert abs(result_mean.item() - 2540529) < 10
+
+
+class VQDiffusionSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (VQDiffusionScheduler,)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_vec_classes": 4097,
+            "num_train_timesteps": 100,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def dummy_sample(self, num_vec_classes):
+        batch_size = 4
+        height = 8
+        width = 8
+
+        sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        assert False
+
+    def dummy_model(self, num_vec_classes):
+        def model(sample, t, *args):
+            batch_size, num_latent_pixels = sample.shape
+            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
+            return_value = F.log_softmax(logits.double(), dim=1).float()
+            return return_value
+
+        return model
+
+    def test_timesteps(self):
+        for timesteps in [2, 5, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_num_vec_classes(self):
+        for num_vec_classes in [5, 100, 1000, 4000]:
+            self.check_over_configs(num_vec_classes=num_vec_classes)
+
+    def test_time_indices(self):
+        for t in [0, 50, 99]:
+            self.check_over_forward(time_step=t)
+
+    def test_add_noise_device(self):
+        pass