1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

VQ-diffusion (#658)

* Changes for VQ-diffusion VQVAE

Add specify dimension of embeddings to VQModel:
`VQModel` will by default set the dimension of embeddings to the number
of latent channels. The VQ-diffusion VQVAE has a smaller
embedding dimension, 128, than number of latent channels, 256.

Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the up and down
unet block helpers. VQ-diffusion's VQVAE uses those two block types.

* Changes for VQ-diffusion transformer

Modify attention.py so SpatialTransformer can be used for
VQ-diffusion's transformer.

SpatialTransformer:
- Can now operate over discrete inputs (classes of vector embeddings) as well as continuous.
- `in_channels` was made optional in the constructor so two locations where it was passed as a positional arg were moved to kwargs
- modified forward pass to take optional timestep embeddings

ImagePositionalEmbeddings:
- added to provide positional embeddings to discrete inputs for latent pixels

BasicTransformerBlock:
- norm layers were made configurable so that the VQ-diffusion could use AdaLayerNorm with timestep embeddings
- modified forward pass to take optional timestep embeddings

CrossAttention:
- now may optionally take a bias parameter for its query, key, and value linear layers

FeedForward:
- Internal layers are now configurable

ApproximateGELU:
- Activation function in VQ-diffusion's feedforward layer

AdaLayerNorm:
- Norm layer modified to incorporate timestep embeddings

* Add VQ-diffusion scheduler

* Add VQ-diffusion pipeline

* Add VQ-diffusion convert script to diffusers

* Add VQ-diffusion dummy objects

* Add VQ-diffusion markdown docs

* Add VQ-diffusion tests

* some renaming

* some fixes

* more renaming

* correct

* fix typo

* correct weights

* finalize

* fix tests

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* finish

* finish

* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Will Berman
2022-11-03 08:10:28 -07:00
committed by GitHub
parent 269109dbfb
commit ef2ea33c3b
25 changed files with 2674 additions and 223 deletions

View File

View File

@@ -0,0 +1,175 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.utils import load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def num_embed(self):
return 12
@property
def num_embeds_ada_norm(self):
return 12
@property
def dummy_vqvae(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
num_vq_embeddings=self.num_embed,
vq_embed_dim=3,
)
return model
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_transformer(self):
torch.manual_seed(0)
height = 12
width = 12
model_kwargs = {
"attention_bias": True,
"cross_attention_dim": 32,
"attention_head_dim": height * width,
"num_attention_heads": 1,
"num_vector_embeds": self.num_embed,
"num_embeds_ada_norm": self.num_embeds_ada_norm,
"norm_num_groups": 32,
"sample_size": width,
"activation_fn": "geglu-approximate",
}
model = Transformer2DModel(**model_kwargs)
return model
def test_vq_diffusion(self):
device = "cpu"
vqvae = self.dummy_vqvae
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
transformer = self.dummy_transformer
scheduler = VQDiffusionScheduler(self.num_embed)
pipe = VQDiffusionPipeline(
vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
prompt = "teddy bear playing in the pool"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = pipe(
[prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 24, 24, 3)
expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_vq_diffusion(self):
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
"teddy bear playing in the pool",
truncation_rate=0.86,
num_images_per_prompt=1,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2

View File

@@ -18,8 +18,9 @@ import unittest
import numpy as np
import torch
from torch import nn
from diffusers.models.attention import AttentionBlock, SpatialTransformer
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.utils import torch_device
@@ -235,7 +236,7 @@ class AttentionBlockTests(unittest.TestCase):
num_head_channels=1,
rescale_output_factor=1.0,
eps=1e-6,
num_groups=32,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -259,7 +260,7 @@ class AttentionBlockTests(unittest.TestCase):
channels=512,
rescale_output_factor=1.0,
eps=1e-6,
num_groups=32,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -273,22 +274,22 @@ class AttentionBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class SpatialTransformerTests(unittest.TestCase):
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = SpatialTransformer(
spatial_transformer_block = Transformer2DModel(
in_channels=32,
n_heads=1,
d_head=32,
num_attention_heads=1,
attention_head_dim=32,
dropout=0.0,
context_dim=None,
cross_attention_dim=None,
).to(torch_device)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample)
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -298,22 +299,22 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_context_dim(self):
def test_spatial_transformer_cross_attention_dim(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = SpatialTransformer(
spatial_transformer_block = Transformer2DModel(
in_channels=64,
n_heads=2,
d_head=32,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
context_dim=64,
cross_attention_dim=64,
).to(torch_device)
with torch.no_grad():
context = torch.randn(1, 4, 64).to(torch_device)
attention_scores = spatial_transformer_block(sample, context)
attention_scores = spatial_transformer_block(sample, context).sample
assert attention_scores.shape == (1, 64, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -323,6 +324,44 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embeds_ada_norm = 5
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
in_channels=64,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
cross_attention_dim=64,
num_embeds_ada_norm=num_embeds_ada_norm,
).to(torch_device)
with torch.no_grad():
timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
assert attention_scores_1.shape == (1, 64, 64, 64)
assert attention_scores_2.shape == (1, 64, 64, 64)
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
expected_slice_1 = torch.tensor(
[-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
)
expected_slice_2 = torch.tensor(
[-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
)
assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)
def test_spatial_transformer_dropout(self):
torch.manual_seed(0)
if torch.cuda.is_available():
@@ -330,18 +369,18 @@ class SpatialTransformerTests(unittest.TestCase):
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = (
SpatialTransformer(
Transformer2DModel(
in_channels=32,
n_heads=2,
d_head=16,
num_attention_heads=2,
attention_head_dim=16,
dropout=0.3,
context_dim=None,
cross_attention_dim=None,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample)
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -350,3 +389,107 @@ class SpatialTransformerTests(unittest.TestCase):
[-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embed = 5
sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
spatial_transformer_block = (
Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
num_vector_embeds=num_embed,
sample_size=16,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, num_embed - 1, 32)
output_slice = attention_scores[0, -2:, -3:]
expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_default_norm_layers(self):
spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_ada_norm_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
num_embeds_ada_norm=5,
)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_default_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
# NOTE: inner_dim * 2 because GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_geglu_approx_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
activation_fn="geglu-approximate",
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_attention_bias(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
)
assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None

View File

@@ -19,6 +19,7 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
@@ -29,6 +30,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from diffusers.utils import torch_device
@@ -85,12 +87,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -122,12 +130,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -154,15 +168,21 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
@@ -200,8 +220,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep_0)
else:
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -255,8 +281,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -284,22 +316,26 @@ class SchedulerCommonTest(unittest.TestCase):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
self.assertTrue(
hasattr(scheduler, "init_noise_sigma"),
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
)
if scheduler_class != VQDiffusionScheduler:
self.assertTrue(
hasattr(scheduler, "init_noise_sigma"),
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`",
)
self.assertTrue(
hasattr(scheduler, "step"),
f"{scheduler_class} does not implement a required class method `step(...)`",
)
sample = self.dummy_sample
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
if scheduler_class != VQDiffusionScheduler:
sample = self.dummy_sample
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
def test_add_noise_device(self):
for scheduler_class in self.scheduler_classes:
@@ -1238,3 +1274,53 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 2540529) < 10
class VQDiffusionSchedulerTest(SchedulerCommonTest):
scheduler_classes = (VQDiffusionScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_vec_classes": 4097,
"num_train_timesteps": 100,
}
config.update(**kwargs)
return config
def dummy_sample(self, num_vec_classes):
batch_size = 4
height = 8
width = 8
sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
return sample
@property
def dummy_sample_deter(self):
assert False
def dummy_model(self, num_vec_classes):
def model(sample, t, *args):
batch_size, num_latent_pixels = sample.shape
logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
return_value = F.log_softmax(logits.double(), dim=1).float()
return return_value
return model
def test_timesteps(self):
for timesteps in [2, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_num_vec_classes(self):
for num_vec_classes in [5, 100, 1000, 4000]:
self.check_over_configs(num_vec_classes=num_vec_classes)
def test_time_indices(self):
for t in [0, 50, 99]:
self.check_over_forward(time_step=t)
def test_add_noise_device(self):
pass