mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Big Model Renaming (#109)
* up * change model name * renaming * more changes * up * up * up * save checkpoint * finish api / naming * finish config renaming * rename all weights * finish really
This commit is contained in:
committed by
GitHub
parent
13e37cabe0
commit
9c3820d05a
@@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import UNetUnconditionalModel, DDIMScheduler
|
||||
from diffusers import UNet2DModel, DDIMScheduler
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import tqdm
|
||||
@@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# 1. Load models
|
||||
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
|
||||
unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
|
||||
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
|
||||
|
||||
# 2. Sample gaussian noise
|
||||
generator = torch.manual_seed(23)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
|
||||
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
@@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
|
||||
continue
|
||||
|
||||
new_path = new_path.replace('down.', 'downsample_blocks.')
|
||||
new_path = new_path.replace('up.', 'upsample_blocks.')
|
||||
new_path = new_path.replace('up.', 'up_blocks.')
|
||||
|
||||
if additional_replacements is not None:
|
||||
for replacement in additional_replacements:
|
||||
@@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
||||
num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
|
||||
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
|
||||
|
||||
num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
|
||||
upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
|
||||
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
|
||||
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
|
||||
|
||||
for i in range(num_downsample_blocks):
|
||||
block_id = (i - 1) // (config['num_res_blocks'] + 1)
|
||||
@@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
||||
])
|
||||
|
||||
for i in range(num_upsample_blocks):
|
||||
block_id = num_upsample_blocks - 1 - i
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
|
||||
if any('upsample' in layer for layer in upsample_blocks[i]):
|
||||
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
|
||||
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
|
||||
if any('upsample' in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
|
||||
|
||||
if any('block' in layer for layer in upsample_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any('block' in layer for layer in up_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['num_res_blocks'] + 1):
|
||||
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
if any('attn' in layer for layer in upsample_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
if any('attn' in layer for layer in up_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['num_res_blocks'] + 1):
|
||||
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
|
||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ if __name__ == "__main__":
|
||||
if "ddpm" in config:
|
||||
del config["ddpm"]
|
||||
|
||||
model = UNetUnconditionalModel(**config)
|
||||
model = UNet2DModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
|
||||
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
|
||||
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
@@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
to_split = {
|
||||
'middle_block.1.qkv.bias': {
|
||||
'key': 'mid.attentions.0.key.bias',
|
||||
'query': 'mid.attentions.0.query.bias',
|
||||
'value': 'mid.attentions.0.value.bias',
|
||||
'key': 'mid_block.attentions.0.key.bias',
|
||||
'query': 'mid_block.attentions.0.query.bias',
|
||||
'value': 'mid_block.attentions.0.value.bias',
|
||||
},
|
||||
'middle_block.1.qkv.weight': {
|
||||
'key': 'mid.attentions.0.key.weight',
|
||||
'query': 'mid.attentions.0.query.weight',
|
||||
'value': 'mid.attentions.0.value.weight',
|
||||
'key': 'mid_block.attentions.0.key.weight',
|
||||
'query': 'mid_block.attentions.0.query.weight',
|
||||
'value': 'mid_block.attentions.0.value.weight',
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
|
||||
@@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
resnet_0_paths = renew_resnet_paths(resnets)
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
||||
|
||||
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
||||
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
|
||||
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
|
||||
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
@@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
'old': f'output_blocks.{i}.1',
|
||||
'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
|
||||
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
|
||||
}
|
||||
to_split = {
|
||||
f'output_blocks.{i}.1.qkv.bias': {
|
||||
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
||||
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
||||
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
|
||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
|
||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
|
||||
},
|
||||
f'output_blocks.{i}.1.qkv.weight': {
|
||||
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
||||
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
||||
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
||||
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
|
||||
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
|
||||
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
@@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
old_path = '.'.join(['output_blocks', str(i), path['old']])
|
||||
new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
|
||||
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
|
||||
|
||||
new_checkpoint[new_path] = checkpoint[old_path]
|
||||
|
||||
@@ -319,7 +319,7 @@ if __name__ == "__main__":
|
||||
if "ldm" in config:
|
||||
del config["ldm"]
|
||||
|
||||
model = UNetUnconditionalModel(**config)
|
||||
model = UNet2DModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
|
||||
try:
|
||||
|
||||
@@ -17,16 +17,16 @@
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from diffusers import UNetUnconditionalModel
|
||||
from diffusers import UNet2DModel
|
||||
|
||||
|
||||
def convert_ncsnpp_checkpoint(checkpoint, config):
|
||||
"""
|
||||
Takes a state dict and the path to
|
||||
"""
|
||||
new_model_architecture = UNetUnconditionalModel(**config)
|
||||
new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
|
||||
new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
|
||||
new_model_architecture = UNet2DModel(**config)
|
||||
new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
|
||||
new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
|
||||
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
|
||||
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data
|
||||
|
||||
@@ -92,14 +92,14 @@ def convert_ncsnpp_checkpoint(checkpoint, config):
|
||||
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
|
||||
module_index += 1
|
||||
|
||||
set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
|
||||
set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
|
||||
module_index += 1
|
||||
set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
|
||||
set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
|
||||
module_index += 1
|
||||
set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
|
||||
set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
|
||||
module_index += 1
|
||||
|
||||
for i, block in enumerate(new_model_architecture.upsample_blocks):
|
||||
for i, block in enumerate(new_model_architecture.up_blocks):
|
||||
has_attentions = hasattr(block, "attentions")
|
||||
for j in range(len(block.resnets)):
|
||||
set_resnet_weights(block.resnets[j], checkpoint, module_index)
|
||||
@@ -134,7 +134,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
|
||||
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the checkpoint to convert.",
|
||||
@@ -171,7 +171,7 @@ if __name__ == "__main__":
|
||||
if "sde" in config:
|
||||
del config["sde"]
|
||||
|
||||
model = UNetUnconditionalModel(**config)
|
||||
model = UNet2DModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from huggingface_hub import HfApi
|
||||
from transformers.file_utils import has_file
|
||||
from diffusers import UNetUnconditionalModel
|
||||
from diffusers import UNet2DModel
|
||||
import random
|
||||
import torch
|
||||
api = HfApi()
|
||||
@@ -70,19 +70,22 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0
|
||||
models = api.list_models(filter="diffusers")
|
||||
for mod in models:
|
||||
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
|
||||
|
||||
if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
|
||||
model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
|
||||
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
|
||||
|
||||
print(f"Started running {mod.modelId}!!!")
|
||||
|
||||
if mod.modelId.startswith("CompVis"):
|
||||
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
|
||||
else:
|
||||
model = UNetUnconditionalModel.from_pretrained(mod.modelId)
|
||||
model = UNet2DModel.from_pretrained(local_checkpoint)
|
||||
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
with torch.no_grad():
|
||||
logits = model(noise, time_step)['sample']
|
||||
|
||||
torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
|
||||
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
|
||||
print(f"{mod.modelId} has passed succesfully!!!")
|
||||
|
||||
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
|
||||
__version__ = "0.0.4"
|
||||
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
|
||||
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
|
||||
|
||||
@@ -161,10 +161,10 @@ class ConfigMixin:
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
|
||||
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
|
||||
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
|
||||
" pass `use_auth_token=True`."
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
||||
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
||||
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
||||
" login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
WEIGHTS_NAME = "diffusion_model.pt"
|
||||
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module):
|
||||
models, `pixel_values` for vision models and `input_values` for speech models).
|
||||
"""
|
||||
config_name = CONFIG_NAME
|
||||
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
|
||||
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -341,7 +341,7 @@ class ModelMixin(torch.nn.Module):
|
||||
subfolder=subfolder,
|
||||
**kwargs,
|
||||
)
|
||||
model.register_to_config(name_or_path=pretrained_model_name_or_path)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
@@ -497,46 +497,45 @@ class ModelMixin(torch.nn.Module):
|
||||
)
|
||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||
|
||||
if False:
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
||||
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
||||
" identical (initializing a BertForSequenceClassification model from a"
|
||||
" BertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
elif len(mismatched_keys) == 0:
|
||||
logger.info(
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
||||
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
||||
" without further training."
|
||||
)
|
||||
if len(mismatched_keys) > 0:
|
||||
mismatched_warning = "\n".join(
|
||||
[
|
||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||
for key, shape1, shape2 in mismatched_keys
|
||||
]
|
||||
)
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
||||
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
||||
" able to use it for predictions and inference."
|
||||
)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
||||
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
||||
" identical (initializing a BertForSequenceClassification model from a"
|
||||
" BertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
elif len(mismatched_keys) == 0:
|
||||
logger.info(
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
||||
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
||||
" without further training."
|
||||
)
|
||||
if len(mismatched_keys) > 0:
|
||||
mismatched_warning = "\n".join(
|
||||
[
|
||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||
for key, shape1, shape2 in mismatched_keys
|
||||
]
|
||||
)
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
||||
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
||||
" able to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
|
||||
|
||||
@@ -16,6 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .unet_conditional import UNetConditionalModel
|
||||
from .unet_unconditional import UNetUnconditionalModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .vae import AutoencoderKL, VQModel
|
||||
|
||||
@@ -17,7 +17,6 @@ class AttentionBlockNew(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=None,
|
||||
num_groups=32,
|
||||
rescale_output_factor=1.0,
|
||||
@@ -25,14 +24,8 @@ class AttentionBlockNew(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels is None:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
|
||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
||||
self.num_head_size = num_head_channels
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
|
||||
|
||||
@@ -78,12 +78,11 @@ class Downsample2D(nn.Module):
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.op = conv
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
182
src/diffusers/models/unet_2d.py
Normal file
182
src/diffusers/models/unet_2d.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=None,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
center_input_sample=False,
|
||||
time_embedding_type="positional",
|
||||
freq_shift=0,
|
||||
flip_sin_to_cos=True,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels=(224, 448, 672, 896),
|
||||
layers_per_block=2,
|
||||
mid_block_scale_factor=1,
|
||||
downsample_padding=1,
|
||||
act_fn="silu",
|
||||
attention_head_dim=8,
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_groups=norm_num_groups,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
output = {"sample": sample}
|
||||
|
||||
return output
|
||||
178
src/diffusers/models/unet_2d_condition.py
Normal file
178
src/diffusers/models/unet_2d_condition.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
|
||||
|
||||
class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=None,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
center_input_sample=False,
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels=(320, 640, 1280, 1280),
|
||||
layers_per_block=2,
|
||||
downsample_padding=1,
|
||||
mid_block_scale_factor=1,
|
||||
act_fn="silu",
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
attention_head_dim=8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_groups=norm_num_groups,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
|
||||
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# 5. up
|
||||
for upsample_block in self.up_blocks:
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
|
||||
|
||||
# 6. post-process
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
output = {"sample": sample}
|
||||
|
||||
return output
|
||||
@@ -33,8 +33,9 @@ def get_down_block(
|
||||
attn_num_head_channels,
|
||||
downsample_padding=None,
|
||||
):
|
||||
if down_block_type == "UNetResDownBlock2D":
|
||||
return UNetResDownBlock2D(
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
return DownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -44,8 +45,8 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "UNetResAttnDownBlock2D":
|
||||
return UNetResAttnDownBlock2D(
|
||||
elif down_block_type == "AttnDownBlock2D":
|
||||
return AttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -56,8 +57,8 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif down_block_type == "UNetResCrossAttnDownBlock2D":
|
||||
return UNetResCrossAttnDownBlock2D(
|
||||
elif down_block_type == "CrossAttnDownBlock2D":
|
||||
return CrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -68,8 +69,8 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif down_block_type == "UNetResSkipDownBlock2D":
|
||||
return UNetResSkipDownBlock2D(
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -79,8 +80,8 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "UNetResAttnSkipDownBlock2D":
|
||||
return UNetResAttnSkipDownBlock2D(
|
||||
elif down_block_type == "AttnSkipDownBlock2D":
|
||||
return AttnSkipDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -105,8 +106,9 @@ def get_up_block(
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
):
|
||||
if up_block_type == "UNetResUpBlock2D":
|
||||
return UNetResUpBlock2D(
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
return UpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -116,8 +118,8 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
)
|
||||
elif up_block_type == "UNetResCrossAttnUpBlock2D":
|
||||
return UNetResCrossAttnUpBlock2D(
|
||||
elif up_block_type == "CrossAttnUpBlock2D":
|
||||
return CrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -128,8 +130,8 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif up_block_type == "UNetResAttnUpBlock2D":
|
||||
return UNetResAttnUpBlock2D(
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
return AttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -140,8 +142,8 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
elif up_block_type == "UNetResSkipUpBlock2D":
|
||||
return UNetResSkipUpBlock2D(
|
||||
elif up_block_type == "SkipUpBlock2D":
|
||||
return SkipUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -151,8 +153,8 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
)
|
||||
elif up_block_type == "UNetResAttnSkipUpBlock2D":
|
||||
return UNetResAttnSkipUpBlock2D(
|
||||
elif up_block_type == "AttnSkipUpBlock2D":
|
||||
return AttnSkipUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@@ -322,7 +324,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetResAttnDownBlock2D(nn.Module):
|
||||
class AttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -403,7 +405,7 @@ class UNetResAttnDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UNetResCrossAttnDownBlock2D(nn.Module):
|
||||
class CrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -485,7 +487,7 @@ class UNetResCrossAttnDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UNetResDownBlock2D(nn.Module):
|
||||
class DownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -551,7 +553,7 @@ class UNetResDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UNetResAttnSkipDownBlock2D(nn.Module):
|
||||
class AttnSkipDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -644,7 +646,7 @@ class UNetResAttnSkipDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states, skip_sample
|
||||
|
||||
|
||||
class UNetResSkipDownBlock2D(nn.Module):
|
||||
class SkipDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -723,7 +725,7 @@ class UNetResSkipDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states, skip_sample
|
||||
|
||||
|
||||
class UNetResAttnUpBlock2D(nn.Module):
|
||||
class AttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -801,7 +803,7 @@ class UNetResAttnUpBlock2D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetResCrossAttnUpBlock2D(nn.Module):
|
||||
class CrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -881,7 +883,7 @@ class UNetResCrossAttnUpBlock2D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetResUpBlock2D(nn.Module):
|
||||
class UpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -944,7 +946,7 @@ class UNetResUpBlock2D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetResAttnSkipUpBlock2D(nn.Module):
|
||||
class AttnSkipUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -1055,7 +1057,7 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
|
||||
return hidden_states, skip_sample
|
||||
|
||||
|
||||
class UNetResSkipUpBlock2D(nn.Module):
|
||||
class SkipUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
|
||||
|
||||
class UNetConditionalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
|
||||
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
|
||||
rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
image_size=None,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_res_blocks=2,
|
||||
dropout=0,
|
||||
block_channels=(320, 640, 1280, 1280),
|
||||
down_blocks=(
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResCrossAttnDownBlock2D",
|
||||
"UNetResDownBlock2D",
|
||||
),
|
||||
downsample_padding=1,
|
||||
up_blocks=(
|
||||
"UNetResUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
"UNetResCrossAttnUpBlock2D",
|
||||
),
|
||||
resnet_act_fn="silu",
|
||||
resnet_eps=1e-5,
|
||||
conv_resample=True,
|
||||
num_head_channels=8,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0,
|
||||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
resnet_num_groups=30,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
time_embed_dim = block_channels[0] * 4
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
|
||||
timestep_input_dim = block_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
self.downsample_blocks = nn.ModuleList([])
|
||||
self.mid = None
|
||||
self.upsample_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_channels[0]
|
||||
for i, down_block_type in enumerate(down_blocks):
|
||||
input_channel = output_channel
|
||||
output_channel = block_channels[i]
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=num_res_blocks,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
self.downsample_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_channels[-1],
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=num_head_channels,
|
||||
resnet_groups=resnet_num_groups,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_channels = list(reversed(block_channels))
|
||||
output_channel = reversed_block_channels[0]
|
||||
for i, up_block_type in enumerate(up_blocks):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_channels[i]
|
||||
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=num_res_blocks + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
)
|
||||
self.upsample_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_steps(timesteps)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.downsample_blocks:
|
||||
|
||||
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# 5. up
|
||||
for upsample_block in self.upsample_blocks:
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
|
||||
|
||||
# 6. post-process
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
output = {"sample": sample}
|
||||
|
||||
return output
|
||||
@@ -1,212 +0,0 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
|
||||
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
|
||||
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
|
||||
rates at which
|
||||
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
|
||||
downsampling, attention will be used.
|
||||
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
|
||||
conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
|
||||
model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
|
||||
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
|
||||
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
image_size=None,
|
||||
in_channels=None,
|
||||
out_channels=None,
|
||||
num_res_blocks=None,
|
||||
dropout=0,
|
||||
block_channels=(224, 448, 672, 896),
|
||||
down_blocks=(
|
||||
"UNetResDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
"UNetResAttnDownBlock2D",
|
||||
),
|
||||
downsample_padding=1,
|
||||
up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
|
||||
resnet_act_fn="silu",
|
||||
resnet_eps=1e-5,
|
||||
conv_resample=True,
|
||||
num_head_channels=32,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0,
|
||||
time_embedding_type="positional",
|
||||
mid_block_scale_factor=1,
|
||||
center_input_sample=False,
|
||||
resnet_num_groups=32,
|
||||
):
|
||||
        super().__init__()
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
            timestep_input_dim = 2 * block_channels[0]
        elif time_embedding_type == "positional":
            self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
            timestep_input_dim = block_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
        self.mid = UNetMidBlock2D(
            in_channels=block_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        output = {"sample": sample}

        return output
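The deleted `UNetUnconditionalModel` above survives as `UNet2DModel`. A minimal sketch of the old-to-new constructor mapping, using only argument renames that appear in the test hunks below (values are illustrative, copied from the small test config):

```python
from diffusers import UNet2DModel

# Renames per this commit:
#   image_size        -> sample_size
#   num_res_blocks    -> layers_per_block
#   block_channels    -> block_out_channels
#   num_head_channels -> attention_head_dim
#   down_blocks       -> down_block_types ("UNetResDownBlock2D" -> "DownBlock2D", ...)
#   up_blocks         -> up_block_types   ("UNetResAttnUpBlock2D" -> "AttnUpBlock2D", ...)
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=None,
)
```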
@@ -25,7 +25,7 @@ from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging


INDEX_FILE = "diffusion_model.pt"
INDEX_FILE = "diffusion_pytorch_model.bin"


logger = logging.get_logger(__name__)
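With the renamed weights file, a model directory saved after this commit should look roughly like this (the directory name and the `config.json` written alongside the weights are assumptions for illustration, not shown in the diff):

```
ddpm-cifar10-32/
├── config.json                   # model hyperparameters (assumed filename)
└── diffusion_pytorch_model.bin   # the renamed weights file
```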
@@ -28,7 +28,9 @@ class DDIMPipeline(DiffusionPipeline):
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
    def __call__(
        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
    ):
        # eta corresponds to η in paper and should be between [0, 1]
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -37,7 +39,7 @@ class DDIMPipeline(DiffusionPipeline):

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(torch_device)
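A usage sketch of the reformatted `__call__` (the checkpoint id is the one the tests below use; the `["sample"]` output key is assumed by analogy with the `test_ldm_uncond` hunk at the end):

```python
import torch

from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
generator = torch.manual_seed(0)
# eta=0.0 keeps DDIM sampling deterministic; the initial noise is now shaped
# from unet.sample_size rather than unet.image_size.
image = pipe(batch_size=1, generator=generator, eta=0.0, num_inference_steps=50, output_type="numpy")["sample"]
```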
@@ -36,7 +36,7 @@ class DDPMPipeline(DiffusionPipeline):

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(torch_device)
@@ -52,7 +52,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
        text_embeddings = self.bert(text_input.input_ids.to(torch_device))

        latents = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        latents = latents.to(torch_device)
@@ -24,7 +24,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
        self.vqvae.to(torch_device)

        latents = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        latents = latents.to(torch_device)
@@ -38,7 +38,7 @@ class PNDMPipeline(DiffusionPipeline):

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(torch_device)
@@ -14,7 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        img_size = self.model.config.sample_size
        shape = (1, 3, img_size, img_size)

        model = self.model.to(device)
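The five pipeline hunks above all make the same one-line change; the shared pattern, pulled out as a standalone helper (a sketch only, with `unet` passed in rather than read from `self`):

```python
import torch


def sample_initial_noise(unet, batch_size=1, generator=None):
    # The noise that starts the denoising loop is now shaped with the renamed
    # `sample_size` config attribute instead of the old `image_size`.
    return torch.randn(
        (batch_size, unet.in_channels, unet.sample_size, unet.sample_size),
        generator=generator,
    )
```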
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import math
import tempfile
@@ -23,7 +22,7 @@ import numpy as np
import torch

import PIL
from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import UNet2DConditionModel  # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
@@ -36,7 +35,7 @@ from diffusers import (
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    UNetUnconditionalModel,
    UNet2DModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -271,7 +270,7 @@ class ModelTesterMixin:


class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetUnconditionalModel
    model_class = UNet2DModel

    @property
    def dummy_input(self):
@@ -294,14 +293,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_channels": (32, 64),
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            "num_head_channels": None,
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": None,
            "out_channels": 3,
            "in_channels": 3,
            "num_res_blocks": 2,
            "image_size": 32,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -309,14 +308,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

    # TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
    # def test_output_pretrained(self):
    #     model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
    #     model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
    #     model.eval()
    #
    #     torch.manual_seed(0)
    #     if torch.cuda.is_available():
    #         torch.cuda.manual_seed_all(0)
    #
    #     noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    #     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    #     time_step = torch.tensor([10])
    #
    #     with torch.no_grad():
@@ -330,7 +329,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetUnconditionalModel
    model_class = UNet2DModel

    @property
    def dummy_input(self):
@@ -353,23 +352,23 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "image_size": 32,
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "num_res_blocks": 2,
            "block_channels": (32, 64),
            "num_head_channels": 32,
            "conv_resample": True,
            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True
        model, loading_info = UNet2DModel.from_pretrained(
            "/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
        )

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -379,14 +378,14 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        time_step = torch.tensor([10] * noise.shape[0])

        with torch.no_grad():
@@ -409,7 +408,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(0)
    #
    # noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    # noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    # context = torch.ones((1, 16, 64), dtype=torch.float32)
    # time_step = torch.tensor([10] * noise.shape[0])
    #
@@ -426,13 +425,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetUnconditionalModel
    model_class = UNet2DModel

    @property
    def dummy_input(self):
    def dummy_input(self, sizes=(32, 32)):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(torch_device)
@@ -449,44 +447,47 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_channels": [32, 64, 64, 64],
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "num_res_blocks": 1,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "resnet_eps": 1e-6,
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "resnet_num_groups": None,
            "down_blocks": [
                "UNetResSkipDownBlock2D",
                "UNetResAttnSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
                "UNetResSkipDownBlock2D",
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_blocks": [
                "UNetResSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
                "UNetResAttnSkipUpBlock2D",
                "UNetResSkipUpBlock2D",
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

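The same dictionary, passed directly to the renamed constructor, gives a variance-exploding (score-SDE) config; a sketch with the new-style keys only (values copied from the test above):

```python
import math

from diffusers import UNet2DModel

model = UNet2DModel(
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=[32, 64, 64, 64],
    time_embedding_type="fourier",  # Gaussian Fourier projection instead of positional embeddings
    norm_eps=1e-6,
    norm_num_groups=None,
    mid_block_scale_factor=math.sqrt(2.0),
    down_block_types=["SkipDownBlock2D", "AttnSkipDownBlock2D", "SkipDownBlock2D", "SkipDownBlock2D"],
    up_block_types=["SkipUpBlock2D", "SkipUpBlock2D", "AttnSkipUpBlock2D", "SkipUpBlock2D"],
)
```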
    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
        model, loading_info = UNet2DModel.from_pretrained(
            "/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)
        inputs = self.dummy_input
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained_ve_mid(self):
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
@@ -511,7 +512,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
@@ -540,10 +541,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = VQModel

    @property
    def dummy_input(self):
    def dummy_input(self, sizes=(32, 32)):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

@@ -570,7 +570,6 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
            "embed_dim": 3,
            "sane_index_shape": False,
            "ch_mult": (1,),
            "dropout": 0.0,
            "double_z": False,
        }
        inputs_dict = self.dummy_input
@@ -583,7 +582,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
        model, loading_info = VQModel.from_pretrained(
            "/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -593,7 +594,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = VQModel.from_pretrained("fusing/vqgan-dummy")
        model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
        model.eval()

        torch.manual_seed(0)
@@ -654,7 +655,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
        pass

    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        model, loading_info = AutoencoderKL.from_pretrained(
            "/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -664,7 +667,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
        model.eval()

        torch.manual_seed(0)
@@ -685,14 +688,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
        model = UNetUnconditionalModel(
            block_channels=(32, 64),
            num_res_blocks=2,
            image_size=32,
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        schedular = DDPMScheduler(num_train_timesteps=10)

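The rest of the test (not shown in the hunk) wires these into a pipeline and round-trips it through disk; a hedged reconstruction, reusing the `model` and `schedular` names defined just above:

```python
import tempfile

from diffusers import DDPMPipeline

ddpm = DDPMPipeline(unet=model, scheduler=schedular)
with tempfile.TemporaryDirectory() as tmpdirname:
    ddpm.save_pretrained(tmpdirname)  # weights land in the renamed diffusion_pytorch_model.bin
    new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
```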
@@ -712,7 +715,7 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_from_pretrained_hub(self):
        model_path = "google/ddpm-cifar10-32"
        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -730,7 +733,7 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_output_format(self):
        model_path = "google/ddpm-cifar10-32"
        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

        pipe = DDIMPipeline.from_pretrained(model_path)

@@ -751,9 +754,9 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ddpm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"
        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id)
        scheduler = scheduler.set_format("pt")

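A hedged reconstruction of how this test continues (the unet/scheduler wiring mirrors the DDIM and PNDM tests below; the `["sample"]` output key follows the `test_ldm_uncond` hunk at the end):

```python
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
```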
@@ -770,9 +773,9 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ddim_lsun(self):
        model_id = "google/ddpm-ema-bedroom-256"
        model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -788,9 +791,9 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ddim_cifar10(self):
        model_id = "google/ddpm-cifar10-32"
        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -806,9 +809,9 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_pndm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"
        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

        unet = UNetUnconditionalModel.from_pretrained(model_id)
        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
@@ -823,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ldm_text2img(self):
        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
@@ -839,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ldm_text2img_fast(self):
        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
@@ -853,13 +856,13 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_score_sde_ve_pipeline(self):
        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
        scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")

        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

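A hedged continuation, driving the pipeline with the `__call__` signature from the `ScoreSdeVePipeline` hunk earlier (`num_inference_steps`, `generator`, `output_type`); the `["sample"]` key and the reduced step count are assumptions for illustration:

```python
generator = torch.manual_seed(0)
image = sde_ve(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
```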
@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):

    @slow
    def test_ldm_uncond(self):
        ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
        ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]