1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Big Model Renaming (#109)

* up

* change model name

* renaming

* more changes

* up

* up

* up

* save checkpoint

* finish api / naming

* finish config renaming

* rename all weights

* finish really
This commit is contained in:
Patrick von Platen
2022-07-21 01:30:45 +02:00
committed by GitHub
parent 13e37cabe0
commit 9c3820d05a
24 changed files with 592 additions and 656 deletions

View File

@@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree
```python
import torch
from diffusers import UNetUnconditionalModel, DDIMScheduler
from diffusers import UNet2DModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm
@@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Load models
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
# 2. Sample gaussian noise
generator = torch.manual_seed(23)

View File

@@ -1,4 +1,4 @@
from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
import argparse
import json
import torch
@@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
continue
new_path = new_path.replace('down.', 'downsample_blocks.')
new_path = new_path.replace('up.', 'upsample_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')
if additional_replacements is not None:
for replacement in additional_replacements:
@@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
@@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])
for i in range(num_upsample_blocks):
block_id = num_upsample_blocks - 1 - i
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any('upsample' in layer for layer in upsample_blocks[i]):
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any('block' in layer for layer in upsample_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if num_blocks > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in upsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if num_attn > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
return new_checkpoint
@@ -225,7 +225,7 @@ if __name__ == "__main__":
if "ddpm" in config:
del config["ddpm"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

View File

@@ -17,7 +17,7 @@
import argparse
import json
import torch
from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
def shave_segments(path, n_shave_prefix_segments=1):
@@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
to_split = {
'middle_block.1.qkv.bias': {
'key': 'mid.attentions.0.key.bias',
'query': 'mid.attentions.0.query.bias',
'value': 'mid.attentions.0.value.bias',
'key': 'mid_block.attentions.0.key.bias',
'query': 'mid_block.attentions.0.query.bias',
'value': 'mid_block.attentions.0.value.bias',
},
'middle_block.1.qkv.weight': {
'key': 'mid.attentions.0.key.weight',
'query': 'mid.attentions.0.query.weight',
'value': 'mid.attentions.0.value.weight',
'key': 'mid_block.attentions.0.key.weight',
'query': 'mid_block.attentions.0.query.weight',
'value': 'mid_block.attentions.0.value.weight',
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
@@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
@@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions)
meta_path = {
'old': f'output_blocks.{i}.1',
'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
}
to_split = {
f'output_blocks.{i}.1.qkv.bias': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
},
f'output_blocks.{i}.1.qkv.weight': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
},
}
assign_to_checkpoint(
@@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
new_checkpoint[new_path] = checkpoint[old_path]
@@ -319,7 +319,7 @@ if __name__ == "__main__":
if "ldm" in config:
del config["ldm"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
try:

View File

@@ -17,16 +17,16 @@
import argparse
import json
import torch
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel
def convert_ncsnpp_checkpoint(checkpoint, config):
"""
Takes a state dict and the path to
"""
new_model_architecture = UNetUnconditionalModel(**config)
new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture = UNet2DModel(**config)
new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data
@@ -92,14 +92,14 @@ def convert_ncsnpp_checkpoint(checkpoint, config):
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
module_index += 1
set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
module_index += 1
for i, block in enumerate(new_model_architecture.upsample_blocks):
for i, block in enumerate(new_model_architecture.up_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j], checkpoint, module_index)
@@ -134,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
type=str,
required=False,
help="Path to the checkpoint to convert.",
@@ -171,7 +171,7 @@ if __name__ == "__main__":
if "sde" in config:
del config["sde"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
try:

View File

@@ -1,6 +1,6 @@
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel
import random
import torch
api = HfApi()
@@ -70,19 +70,22 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
print(f"Started running {mod.modelId}!!!")
if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
model = UNetUnconditionalModel.from_pretrained(mod.modelId)
model = UNet2DModel.from_pretrained(local_checkpoint)
torch.manual_seed(0)
random.seed(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']
torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
print(f"{mod.modelId} has passed succesfully!!!")

View File

@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler

View File

@@ -161,10 +161,10 @@ class ConfigMixin:
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(

View File

@@ -34,7 +34,7 @@ from .utils import (
)
WEIGHTS_NAME = "diffusion_model.pt"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
logger = logging.get_logger(__name__)
@@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module):
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
def __init__(self):
super().__init__()
@@ -341,7 +341,7 @@ class ModelMixin(torch.nn.Module):
subfolder=subfolder,
**kwargs,
)
model.register_to_config(name_or_path=pretrained_model_name_or_path)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -497,46 +497,45 @@ class ModelMixin(torch.nn.Module):
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if False:
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
" identical (initializing a BertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
" without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
" able to use it for predictions and inference."
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
" identical (initializing a BertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
" without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
" able to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

View File

@@ -16,6 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet_conditional import UNetConditionalModel
from .unet_unconditional import UNetUnconditionalModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel

View File

@@ -17,7 +17,6 @@ class AttentionBlockNew(nn.Module):
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
@@ -25,14 +24,8 @@ class AttentionBlockNew(nn.Module):
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

View File

@@ -78,12 +78,11 @@ class Downsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.op = conv
self.conv = conv
def forward(self, x):

View File

@@ -0,0 +1,182 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=3,
out_channels=3,
center_input_sample=False,
time_embedding_type="positional",
freq_shift=0,
flip_sin_to_cos=True,
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels=(224, 448, 672, 896),
layers_per_block=2,
mid_block_scale_factor=1,
downsample_padding=1,
act_fn="silu",
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
return output

View File

@@ -0,0 +1,178 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=4,
out_channels=4,
center_input_sample=False,
flip_sin_to_cos=True,
freq_shift=0,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
attention_head_dim=8,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
return output

View File

@@ -33,8 +33,9 @@ def get_down_block(
attn_num_head_channels,
downsample_padding=None,
):
if down_block_type == "UNetResDownBlock2D":
return UNetResDownBlock2D(
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
return DownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -44,8 +45,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnDownBlock2D":
return UNetResAttnDownBlock2D(
elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -56,8 +57,8 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResCrossAttnDownBlock2D":
return UNetResCrossAttnDownBlock2D(
elif down_block_type == "CrossAttnDownBlock2D":
return CrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -68,8 +69,8 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -79,8 +80,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnSkipDownBlock2D":
return UNetResAttnSkipDownBlock2D(
elif down_block_type == "AttnSkipDownBlock2D":
return AttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -105,8 +106,9 @@ def get_up_block(
resnet_act_fn,
attn_num_head_channels,
):
if up_block_type == "UNetResUpBlock2D":
return UNetResUpBlock2D(
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
return UpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -116,8 +118,8 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResCrossAttnUpBlock2D":
return UNetResCrossAttnUpBlock2D(
elif up_block_type == "CrossAttnUpBlock2D":
return CrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -128,8 +130,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResAttnUpBlock2D":
return UNetResAttnUpBlock2D(
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -140,8 +142,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResSkipUpBlock2D":
return UNetResSkipUpBlock2D(
elif up_block_type == "SkipUpBlock2D":
return SkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -151,8 +153,8 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResAttnSkipUpBlock2D":
return UNetResAttnSkipUpBlock2D(
elif up_block_type == "AttnSkipUpBlock2D":
return AttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
@@ -322,7 +324,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
return hidden_states
class UNetResAttnDownBlock2D(nn.Module):
class AttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -403,7 +405,7 @@ class UNetResAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResCrossAttnDownBlock2D(nn.Module):
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -485,7 +487,7 @@ class UNetResCrossAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResDownBlock2D(nn.Module):
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -551,7 +553,7 @@ class UNetResDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResAttnSkipDownBlock2D(nn.Module):
class AttnSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -644,7 +646,7 @@ class UNetResAttnSkipDownBlock2D(nn.Module):
return hidden_states, output_states, skip_sample
class UNetResSkipDownBlock2D(nn.Module):
class SkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -723,7 +725,7 @@ class UNetResSkipDownBlock2D(nn.Module):
return hidden_states, output_states, skip_sample
class UNetResAttnUpBlock2D(nn.Module):
class AttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -801,7 +803,7 @@ class UNetResAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResCrossAttnUpBlock2D(nn.Module):
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -881,7 +883,7 @@ class UNetResCrossAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResUpBlock2D(nn.Module):
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -944,7 +946,7 @@ class UNetResUpBlock2D(nn.Module):
return hidden_states
class UNetResAttnSkipUpBlock2D(nn.Module):
class AttnSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
@@ -1055,7 +1057,7 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
return hidden_states, skip_sample
class UNetResSkipUpBlock2D(nn.Module):
class SkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,

View File

@@ -1,213 +0,0 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class UNetConditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
@register_to_config
def __init__(
self,
image_size=None,
in_channels=4,
out_channels=4,
num_res_blocks=2,
dropout=0,
block_channels=(320, 640, 1280, 1280),
down_blocks=(
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResDownBlock2D",
),
downsample_padding=1,
up_blocks=(
"UNetResUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=8,
flip_sin_to_cos=True,
downscale_freq_shift=0,
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=30,
):
super().__init__()
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
self.upsample_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=num_res_blocks,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
# mid
self.mid = UNetMidBlock2DCrossAttn(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
is_final_block = i == len(block_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
)
self.upsample_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
for upsample_block in self.upsample_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
return output

View File

@@ -1,212 +0,0 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
@register_to_config
def __init__(
self,
image_size=None,
in_channels=None,
out_channels=None,
num_res_blocks=None,
dropout=0,
block_channels=(224, 448, 672, 896),
down_blocks=(
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
),
downsample_padding=1,
up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=32,
flip_sin_to_cos=True,
downscale_freq_shift=0,
time_embedding_type="positional",
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=32,
):
super().__init__()
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
timestep_input_dim = 2 * block_channels[0]
elif time_embedding_type == "positional":
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
self.upsample_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=num_res_blocks,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
# mid
self.mid = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
)
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
is_final_block = i == len(block_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
)
self.upsample_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
return output

View File

@@ -25,7 +25,7 @@ from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt"
INDEX_FILE = "diffusion_pytorch_model.bin"
logger = logging.get_logger(__name__)

View File

@@ -28,7 +28,9 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
def __call__(
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -37,7 +39,7 @@ class DDIMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)

View File

@@ -36,7 +36,7 @@ class DDPMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)

View File

@@ -52,7 +52,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)

View File

@@ -24,7 +24,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.vqvae.to(torch_device)
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)

View File

@@ -38,7 +38,7 @@ class PNDMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)

View File

@@ -14,7 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
img_size = self.model.config.sample_size
shape = (1, 3, img_size, img_size)
model = self.model.to(device)

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import tempfile
@@ -23,7 +22,7 @@ import numpy as np
import torch
import PIL
from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import UNet2DConditionModel # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import (
AutoencoderKL,
DDIMPipeline,
@@ -36,7 +35,7 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
UNetUnconditionalModel,
UNet2DModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -271,7 +270,7 @@ class ModelTesterMixin:
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
model_class = UNet2DModel
@property
def dummy_input(self):
@@ -294,14 +293,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_channels": (32, 64),
"down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
"up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
"num_head_channels": None,
"block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None,
"out_channels": 3,
"in_channels": 3,
"num_res_blocks": 2,
"image_size": 32,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -309,14 +308,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
# def test_output_pretrained(self):
# model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
# model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
# model.eval()
#
# torch.manual_seed(0)
# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(0)
#
# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
# time_step = torch.tensor([10])
#
# with torch.no_grad():
@@ -330,7 +329,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
model_class = UNet2DModel
@property
def dummy_input(self):
@@ -353,23 +352,23 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"image_size": 32,
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
"num_res_blocks": 2,
"block_channels": (32, 64),
"num_head_channels": 32,
"conv_resample": True,
"down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
"up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
"layers_per_block": 2,
"block_out_channels": (32, 64),
"attention_head_dim": 32,
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True
model, loading_info = UNet2DModel.from_pretrained(
"/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -379,14 +378,14 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
@@ -409,7 +408,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(0)
#
# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
# context = torch.ones((1, 16, 64), dtype=torch.float32)
# time_step = torch.tensor([10] * noise.shape[0])
#
@@ -426,13 +425,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
model_class = UNet2DModel
@property
def dummy_input(self):
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
@@ -449,44 +447,47 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_channels": [32, 64, 64, 64],
"block_out_channels": [32, 64, 64, 64],
"in_channels": 3,
"num_res_blocks": 1,
"layers_per_block": 1,
"out_channels": 3,
"time_embedding_type": "fourier",
"resnet_eps": 1e-6,
"norm_eps": 1e-6,
"mid_block_scale_factor": math.sqrt(2.0),
"resnet_num_groups": None,
"down_blocks": [
"UNetResSkipDownBlock2D",
"UNetResAttnSkipDownBlock2D",
"UNetResSkipDownBlock2D",
"UNetResSkipDownBlock2D",
"norm_num_groups": None,
"down_block_types": [
"SkipDownBlock2D",
"AttnSkipDownBlock2D",
"SkipDownBlock2D",
"SkipDownBlock2D",
],
"up_blocks": [
"UNetResSkipUpBlock2D",
"UNetResSkipUpBlock2D",
"UNetResAttnSkipUpBlock2D",
"UNetResSkipUpBlock2D",
"up_block_types": [
"SkipUpBlock2D",
"SkipUpBlock2D",
"AttnSkipUpBlock2D",
"SkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained(
"fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
model, loading_info = UNet2DModel.from_pretrained(
"/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
inputs = self.dummy_input
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
inputs["sample"] = noise
image = model(**inputs)
assert image is not None, "Make sure output is not None"
def test_output_pretrained_ve_mid(self):
model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
model.to(torch_device)
torch.manual_seed(0)
@@ -511,7 +512,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
model.to(torch_device)
torch.manual_seed(0)
@@ -540,10 +541,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
model_class = VQModel
@property
def dummy_input(self):
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
@@ -570,7 +570,6 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
"embed_dim": 3,
"sane_index_shape": False,
"ch_mult": (1,),
"dropout": 0.0,
"double_z": False,
}
inputs_dict = self.dummy_input
@@ -583,7 +582,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
pass
def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
model, loading_info = VQModel.from_pretrained(
"/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -593,7 +594,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
model.eval()
torch.manual_seed(0)
@@ -654,7 +655,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
pass
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
model, loading_info = AutoencoderKL.from_pretrained(
"/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -664,7 +667,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
model.eval()
torch.manual_seed(0)
@@ -685,14 +688,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetUnconditionalModel(
block_channels=(32, 64),
num_res_blocks=2,
image_size=32,
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
schedular = DDPMScheduler(num_train_timesteps=10)
@@ -712,7 +715,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_from_pretrained_hub(self):
model_path = "google/ddpm-cifar10-32"
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
ddpm = DDPMPipeline.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -730,7 +733,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path)
@@ -751,9 +754,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ddpm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
unet = UNetUnconditionalModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = scheduler.set_format("pt")
@@ -770,9 +773,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ddim_lsun(self):
model_id = "google/ddpm-ema-bedroom-256"
model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"
unet = UNetUnconditionalModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -788,9 +791,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ddim_cifar10(self):
model_id = "google/ddpm-cifar10-32"
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
unet = UNetUnconditionalModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -806,9 +809,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_pndm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
unet = UNetUnconditionalModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
@@ -823,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img(self):
ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -839,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img_fast(self):
ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -853,13 +856,13 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]