diff --git a/README.md b/README.md index 6f8d8a6eea..86000dd018 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree ```python import torch -from diffusers import UNetUnconditionalModel, DDIMScheduler +from diffusers import UNet2DModel, DDIMScheduler import PIL.Image import numpy as np import tqdm @@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Load models scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt") -unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device) +unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device) # 2. Sample gaussian noise generator = torch.manual_seed(23) diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index b1499e285d..216018c6a8 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -1,4 +1,4 @@ -from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline +from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline import argparse import json import torch @@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s continue new_path = new_path.replace('down.', 'downsample_blocks.') - new_path = new_path.replace('up.', 'upsample_blocks.') + new_path = new_path.replace('up.', 'up_blocks.') if additional_replacements is not None: for replacement in additional_replacements: @@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config): num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer}) downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)} - num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer}) - upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)} + num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer}) + up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} for i in range(num_downsample_blocks): block_id = (i - 1) // (config['num_res_blocks'] + 1) @@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config): {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'} ]) - for i in range(num_upsample_blocks): - block_id = num_upsample_blocks - 1 - i + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i - if any('upsample' in layer for layer in upsample_blocks[i]): - new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight'] - new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias'] + if any('upsample' in layer for layer in up_blocks[i]): + new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight'] + new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias'] - if any('block' in layer for layer in upsample_blocks[i]): - num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' 
in layer}) - blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} + if any('block' in layer for layer in up_blocks[i]): + num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer}) + blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)} if num_blocks > 0: for j in range(config['num_res_blocks'] + 1): - replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'} + replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} paths = renew_resnet_paths(blocks[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - if any('attn' in layer for layer in upsample_blocks[i]): - num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer}) - attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} + if any('attn' in layer for layer in up_blocks[i]): + num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer}) + attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)} if num_attn > 0: for j in range(config['num_res_blocks'] + 1): - replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'} + replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'} paths = renew_attention_paths(attns[j]) assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices]) - new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()} + new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()} return new_checkpoint @@ -225,7 +225,7 @@ if __name__ == "__main__": if "ddpm" in config: del config["ddpm"] - model = UNetUnconditionalModel(**config) + model = UNet2DModel(**config) model.load_state_dict(converted_checkpoint) scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1])) diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py index 2ec816f08c..3116bb2754 100644 --- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py @@ -17,7 +17,7 @@ import argparse import json import torch -from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline +from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline def shave_segments(path, n_shave_prefix_segments=1): @@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config): attentions_paths = renew_attention_paths(attentions) to_split = { 'middle_block.1.qkv.bias': { - 'key': 'mid.attentions.0.key.bias', - 'query': 'mid.attentions.0.query.bias', - 'value': 'mid.attentions.0.value.bias', + 'key': 'mid_block.attentions.0.key.bias', + 'query': 'mid_block.attentions.0.query.bias', + 'value': 'mid_block.attentions.0.value.bias', }, 'middle_block.1.qkv.weight': { - 'key': 'mid.attentions.0.key.weight', - 'query': 'mid.attentions.0.query.weight', - 'value': 'mid.attentions.0.value.weight', + 'key': 'mid_block.attentions.0.key.weight', + 'query': 'mid_block.attentions.0.query.weight', + 'value': 
'mid_block.attentions.0.value.weight', }, } assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config) @@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config): resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'} + meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'} assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config) if ['conv.weight', 'conv.bias'] in output_block_list.values(): index = list(output_block_list.values()).index(['conv.weight', 'conv.bias']) - new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight'] - new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias'] + new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight'] + new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias'] # Clear attentions as they have been attributed above. if len(attentions) == 2: @@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config): paths = renew_attention_paths(attentions) meta_path = { 'old': f'output_blocks.{i}.1', - 'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}' + 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}' } to_split = { f'output_blocks.{i}.1.qkv.bias': { - 'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias', - 'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias', - 'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias', + 'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias', + 'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias', + 'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias', }, f'output_blocks.{i}.1.qkv.weight': { - 'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight', - 'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight', - 'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight', + 'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight', + 'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight', + 'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight', }, } assign_to_checkpoint( @@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config): resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = '.'.join(['output_blocks', str(i), path['old']]) - new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']]) + new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']]) new_checkpoint[new_path] = checkpoint[old_path] @@ -319,7 +319,7 @@ if __name__ == "__main__": if "ldm" in config: del config["ldm"] - model = UNetUnconditionalModel(**config) + model = UNet2DModel(**config) model.load_state_dict(converted_checkpoint) try: diff --git a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py 
b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py index a50b780e51..8f02d69154 100644 --- a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py @@ -17,16 +17,16 @@ import argparse import json import torch -from diffusers import UNetUnconditionalModel +from diffusers import UNet2DModel def convert_ncsnpp_checkpoint(checkpoint, config): """ Takes a state dict and the path to """ - new_model_architecture = UNetUnconditionalModel(**config) - new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data - new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data + new_model_architecture = UNet2DModel(**config) + new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data + new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data @@ -92,14 +92,14 @@ def convert_ncsnpp_checkpoint(checkpoint, config): block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data module_index += 1 - set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index) + set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index) module_index += 1 - set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index) + set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index) module_index += 1 - set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index) + set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index) module_index += 1 - for i, block in enumerate(new_model_architecture.upsample_blocks): + for i, block in enumerate(new_model_architecture.up_blocks): has_attentions = hasattr(block, "attentions") for j in range(len(block.resnets)): set_resnet_weights(block.resnets[j], checkpoint, module_index) @@ -134,7 +134,7 @@ if __name__ == "__main__": parser.add_argument( "--checkpoint_path", - default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt", + default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin", type=str, required=False, help="Path to the checkpoint to convert.", @@ -171,7 +171,7 @@ if __name__ == "__main__": if "sde" in config: del config["sde"] - model = UNetUnconditionalModel(**config) + model = UNet2DModel(**config) model.load_state_dict(converted_checkpoint) try: diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py index 352999f16e..4dbe30f7e5 100644 --- a/scripts/generate_logits.py +++ b/scripts/generate_logits.py @@ -1,6 +1,6 @@ from huggingface_hub import HfApi from transformers.file_utils import has_file -from diffusers import UNetUnconditionalModel +from diffusers import UNet2DModel import random import torch api = HfApi() @@ -70,19 +70,22 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0 models = api.list_models(filter="diffusers") for mod in models: if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": - - if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"): - model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet") + local_checkpoint = "/home/patrick/google_checkpoints/" + 
mod.modelId.split("/")[-1] + + print(f"Started running {mod.modelId}!!!") + + if mod.modelId.startswith("CompVis"): + model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet") else: - model = UNetUnconditionalModel.from_pretrained(mod.modelId) + model = UNet2DModel.from_pretrained(local_checkpoint) torch.manual_seed(0) random.seed(0) - noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) + noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) time_step = torch.tensor([10] * noise.shape[0]) with torch.no_grad(): logits = model(noise, time_step)['sample'] - torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3) + assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3) print(f"{mod.modelId} has passed succesfully!!!") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6f50467752..e147a91618 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel +from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .pipeline_utils import DiffusionPipeline from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index a59a1e7988..71cb9b7315 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -161,10 +161,10 @@ class ConfigMixin: except RepositoryNotFoundError: raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed" - " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token" - " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and" - " pass `use_auth_token=True`." + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login` and pass `use_auth_token=True`." ) except RevisionNotFoundError: raise EnvironmentError( diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 4d4bbbdd7b..44a696ca8d 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -34,7 +34,7 @@ from .utils import ( ) -WEIGHTS_NAME = "diffusion_model.pt" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" logger = logging.get_logger(__name__) @@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module): models, `pixel_values` for vision models and `input_values` for speech models). 
""" config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"] + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] def __init__(self): super().__init__() @@ -341,7 +341,7 @@ class ModelMixin(torch.nn.Module): subfolder=subfolder, **kwargs, ) - model.register_to_config(name_or_path=pretrained_model_name_or_path) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # Load model pretrained_model_name_or_path = str(pretrained_model_name_or_path) @@ -497,46 +497,45 @@ class ModelMixin(torch.nn.Module): ) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - if False: - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" - f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" - " without further training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f3b2fe9e82..0c19b49b14 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .unet_conditional import UNetConditionalModel -from .unet_unconditional import UNetUnconditionalModel +from .unet_2d import UNet2DModel +from .unet_2d_condition import UNet2DConditionModel from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 795eca7f63..dd22cdbb95 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,7 +17,6 @@ class AttentionBlockNew(nn.Module): def __init__( self, channels, - num_heads=1, num_head_channels=None, num_groups=32, rescale_output_factor=1.0, @@ -25,14 +24,8 @@ class AttentionBlockNew(nn.Module): ): super().__init__() self.channels = channels - if num_head_channels is None: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 self.num_head_size = num_head_channels self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ade7db825c..a54199c1a2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -78,12 +78,11 @@ class Downsample2D(nn.Module): # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": - self.conv = conv - elif name == "Conv2d_0": self.Conv2d_0 = conv self.conv = conv + elif name == "Conv2d_0": + self.conv = conv else: - self.op = conv self.conv = conv def forward(self, x): diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py new file mode 100644 index 0000000000..6203d76f25 --- /dev/null +++ b/src/diffusers/models/unet_2d.py @@ -0,0 +1,182 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +class UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size=None, + in_channels=3, + out_channels=3, + center_input_sample=False, + time_embedding_type="positional", + freq_shift=0, + flip_sin_to_cos=True, + down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + mid_block_scale_factor=1, + downsample_padding=1, + act_fn="silu", + attention_head_dim=8, + norm_num_groups=32, + norm_eps=1e-5, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = 
nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward( + self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int] + ) -> Dict[str, torch.FloatTensor]: + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb) + + # 5. 
up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + output = {"sample": sample} + + return output diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py new file mode 100644 index 0000000000..ae82e202bf --- /dev/null +++ b/src/diffusers/models/unet_2d_condition.py @@ -0,0 +1,178 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from .embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block + + +class UNet2DConditionModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size=None, + in_channels=4, + out_channels=4, + center_input_sample=False, + flip_sin_to_cos=True, + freq_shift=0, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(320, 640, 1280, 1280), + layers_per_block=2, + downsample_padding=1, + mid_block_scale_factor=1, + act_fn="silu", + norm_num_groups=32, + norm_eps=1e-5, + attention_head_dim=8, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + downsample_padding=downsample_padding, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attn_num_head_channels=attention_head_dim, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + 
output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + attn_num_head_channels=attention_head_dim, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + ) -> Dict[str, torch.FloatTensor]: + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t_emb = self.time_proj(timesteps) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for upsample_block in self.up_blocks: + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) + + # 6. 
post-process + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + output = {"sample": sample} + + return output diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 60ec2f2e06..67082d2409 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -33,8 +33,9 @@ def get_down_block( attn_num_head_channels, downsample_padding=None, ): - if down_block_type == "UNetResDownBlock2D": - return UNetResDownBlock2D( + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -44,8 +45,8 @@ def get_down_block( resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, ) - elif down_block_type == "UNetResAttnDownBlock2D": - return UNetResAttnDownBlock2D( + elif down_block_type == "AttnDownBlock2D": + return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -56,8 +57,8 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, ) - elif down_block_type == "UNetResCrossAttnDownBlock2D": - return UNetResCrossAttnDownBlock2D( + elif down_block_type == "CrossAttnDownBlock2D": + return CrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -68,8 +69,8 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, ) - elif down_block_type == "UNetResSkipDownBlock2D": - return UNetResSkipDownBlock2D( + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -79,8 +80,8 @@ def get_down_block( resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, ) - elif down_block_type == "UNetResAttnSkipDownBlock2D": - return UNetResAttnSkipDownBlock2D( + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -105,8 +106,9 @@ def get_up_block( resnet_act_fn, attn_num_head_channels, ): - if up_block_type == "UNetResUpBlock2D": - return UNetResUpBlock2D( + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -116,8 +118,8 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) - elif up_block_type == "UNetResCrossAttnUpBlock2D": - return UNetResCrossAttnUpBlock2D( + elif up_block_type == "CrossAttnUpBlock2D": + return CrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -128,8 +130,8 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, ) - elif up_block_type == "UNetResAttnUpBlock2D": - return UNetResAttnUpBlock2D( + elif up_block_type == "AttnUpBlock2D": + return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -140,8 +142,8 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, ) - elif up_block_type == "UNetResSkipUpBlock2D": - return UNetResSkipUpBlock2D( + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, 
out_channels=out_channels, @@ -151,8 +153,8 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) - elif up_block_type == "UNetResAttnSkipUpBlock2D": - return UNetResAttnSkipUpBlock2D( + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, @@ -322,7 +324,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): return hidden_states -class UNetResAttnDownBlock2D(nn.Module): +class AttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -403,7 +405,7 @@ class UNetResAttnDownBlock2D(nn.Module): return hidden_states, output_states -class UNetResCrossAttnDownBlock2D(nn.Module): +class CrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -485,7 +487,7 @@ class UNetResCrossAttnDownBlock2D(nn.Module): return hidden_states, output_states -class UNetResDownBlock2D(nn.Module): +class DownBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -551,7 +553,7 @@ class UNetResDownBlock2D(nn.Module): return hidden_states, output_states -class UNetResAttnSkipDownBlock2D(nn.Module): +class AttnSkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -644,7 +646,7 @@ class UNetResAttnSkipDownBlock2D(nn.Module): return hidden_states, output_states, skip_sample -class UNetResSkipDownBlock2D(nn.Module): +class SkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -723,7 +725,7 @@ class UNetResSkipDownBlock2D(nn.Module): return hidden_states, output_states, skip_sample -class UNetResAttnUpBlock2D(nn.Module): +class AttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -801,7 +803,7 @@ class UNetResAttnUpBlock2D(nn.Module): return hidden_states -class UNetResCrossAttnUpBlock2D(nn.Module): +class CrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -881,7 +883,7 @@ class UNetResCrossAttnUpBlock2D(nn.Module): return hidden_states -class UNetResUpBlock2D(nn.Module): +class UpBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -944,7 +946,7 @@ class UNetResUpBlock2D(nn.Module): return hidden_states -class UNetResAttnSkipUpBlock2D(nn.Module): +class AttnSkipUpBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -1055,7 +1057,7 @@ class UNetResAttnSkipUpBlock2D(nn.Module): return hidden_states, skip_sample -class UNetResSkipUpBlock2D(nn.Module): +class SkipUpBlock2D(nn.Module): def __init__( self, in_channels: int, diff --git a/src/diffusers/models/unet_conditional.py b/src/diffusers/models/unet_conditional.py deleted file mode 100644 index 293542f587..0000000000 --- a/src/diffusers/models/unet_conditional.py +++ /dev/null @@ -1,213 +0,0 @@ -from typing import Dict, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block - - -class UNetConditionalModel(ModelMixin, ConfigMixin): - """ - The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param - model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param - num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample - rates at which - attention will take place. May be a set, list, or tuple. 
For example, if this contains 4, then at 4x - downsampling, attention will be used. - :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param - conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this - model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention - heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks - for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - @register_to_config - def __init__( - self, - image_size=None, - in_channels=4, - out_channels=4, - num_res_blocks=2, - dropout=0, - block_channels=(320, 640, 1280, 1280), - down_blocks=( - "UNetResCrossAttnDownBlock2D", - "UNetResCrossAttnDownBlock2D", - "UNetResCrossAttnDownBlock2D", - "UNetResDownBlock2D", - ), - downsample_padding=1, - up_blocks=( - "UNetResUpBlock2D", - "UNetResCrossAttnUpBlock2D", - "UNetResCrossAttnUpBlock2D", - "UNetResCrossAttnUpBlock2D", - ), - resnet_act_fn="silu", - resnet_eps=1e-5, - conv_resample=True, - num_head_channels=8, - flip_sin_to_cos=True, - downscale_freq_shift=0, - mid_block_scale_factor=1, - center_input_sample=False, - resnet_num_groups=30, - ): - super().__init__() - self.image_size = image_size - time_embed_dim = block_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1)) - - # time - self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift) - timestep_input_dim = block_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.downsample_blocks = nn.ModuleList([]) - self.mid = None - self.upsample_blocks = nn.ModuleList([]) - - # down - output_channel = block_channels[0] - for i, down_block_type in enumerate(down_blocks): - input_channel = output_channel - output_channel = block_channels[i] - is_final_block = i == len(block_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=num_res_blocks, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=num_head_channels, - downsample_padding=downsample_padding, - ) - self.downsample_blocks.append(down_block) - - # mid - self.mid = UNetMidBlock2DCrossAttn( - in_channels=block_channels[-1], - dropout=dropout, - temb_channels=time_embed_dim, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - attn_num_head_channels=num_head_channels, - resnet_groups=resnet_num_groups, - ) - - # up - reversed_block_channels = list(reversed(block_channels)) - output_channel = reversed_block_channels[0] - for i, up_block_type in enumerate(up_blocks): - prev_output_channel = output_channel - output_channel = reversed_block_channels[i] - 
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)] - - is_final_block = i == len(block_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=num_res_blocks + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=num_head_channels, - ) - self.upsample_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - ) -> Dict[str, torch.FloatTensor]: - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - t_emb = self.time_steps(timesteps) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.downsample_blocks: - - if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: - sample, res_samples = downsample_block( - hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states) - - # 5. up - for upsample_block in self.upsample_blocks: - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) - - # 6. post-process - - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - output = {"sample": sample} - - return output diff --git a/src/diffusers/models/unet_unconditional.py b/src/diffusers/models/unet_unconditional.py deleted file mode 100644 index c809374a6f..0000000000 --- a/src/diffusers/models/unet_unconditional.py +++ /dev/null @@ -1,212 +0,0 @@ -from typing import Dict, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block - - -class UNetUnconditionalModel(ModelMixin, ConfigMixin): - """ - The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param - model_channels: base channel count for the model. 
:param out_channels: channels in the output Tensor. :param - num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample - rates at which - attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x - downsampling, attention will be used. - :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param - conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this - model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention - heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks - for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - @register_to_config - def __init__( - self, - image_size=None, - in_channels=None, - out_channels=None, - num_res_blocks=None, - dropout=0, - block_channels=(224, 448, 672, 896), - down_blocks=( - "UNetResDownBlock2D", - "UNetResAttnDownBlock2D", - "UNetResAttnDownBlock2D", - "UNetResAttnDownBlock2D", - ), - downsample_padding=1, - up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"), - resnet_act_fn="silu", - resnet_eps=1e-5, - conv_resample=True, - num_head_channels=32, - flip_sin_to_cos=True, - downscale_freq_shift=0, - time_embedding_type="positional", - mid_block_scale_factor=1, - center_input_sample=False, - resnet_num_groups=32, - ): - super().__init__() - self.image_size = image_size - time_embed_dim = block_channels[0] * 4 - - # input - self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1)) - - # time - if time_embedding_type == "fourier": - self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16) - timestep_input_dim = 2 * block_channels[0] - elif time_embedding_type == "positional": - self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift) - timestep_input_dim = block_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.downsample_blocks = nn.ModuleList([]) - self.mid = None - self.upsample_blocks = nn.ModuleList([]) - - # down - output_channel = block_channels[0] - for i, down_block_type in enumerate(down_blocks): - input_channel = output_channel - output_channel = block_channels[i] - is_final_block = i == len(block_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=num_res_blocks, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=num_head_channels, - downsample_padding=downsample_padding, - ) - self.downsample_blocks.append(down_block) - - # mid - self.mid = UNetMidBlock2D( - in_channels=block_channels[-1], - dropout=dropout, - temb_channels=time_embed_dim, - 
resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift="default",
-            attn_num_head_channels=num_head_channels,
-            resnet_groups=resnet_num_groups,
-        )
-
-        # up
-        reversed_block_channels = list(reversed(block_channels))
-        output_channel = reversed_block_channels[0]
-        for i, up_block_type in enumerate(up_blocks):
-            prev_output_channel = output_channel
-            output_channel = reversed_block_channels[i]
-            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
-
-            is_final_block = i == len(block_channels) - 1
-
-            up_block = get_up_block(
-                up_block_type,
-                num_layers=num_res_blocks + 1,
-                in_channels=input_channel,
-                out_channels=output_channel,
-                prev_output_channel=prev_output_channel,
-                temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
-                resnet_eps=resnet_eps,
-                resnet_act_fn=resnet_act_fn,
-                attn_num_head_channels=num_head_channels,
-            )
-            self.upsample_blocks.append(up_block)
-            prev_output_channel = output_channel
-
-        # out
-        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
-        self.conv_act = nn.SiLU()
-        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
-
-    def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
-
-        # 0. center input if necessary
-        if self.config.center_input_sample:
-            sample = 2 * sample - 1.0
-
-        # 1. time
-        timesteps = timestep
-        if not torch.is_tensor(timesteps):
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(sample.device)
-
-        t_emb = self.time_steps(timesteps)
-        emb = self.time_embedding(t_emb)
-
-        # 2. pre-process
-        skip_sample = sample
-        sample = self.conv_in(sample)
-
-        # 3. down
-        down_block_res_samples = (sample,)
-        for downsample_block in self.downsample_blocks:
-            if hasattr(downsample_block, "skip_conv"):
-                sample, res_samples, skip_sample = downsample_block(
-                    hidden_states=sample, temb=emb, skip_sample=skip_sample
-                )
-            else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
-
-            down_block_res_samples += res_samples
-
-        # 4. mid
-        sample = self.mid(sample, emb)
-
-        # 5. up
-        skip_sample = None
-        for upsample_block in self.upsample_blocks:
-            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
-
-            if hasattr(upsample_block, "skip_conv"):
-                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
-            else:
-                sample = upsample_block(sample, res_samples, emb)
-
-        # 6. post-process
-        sample = self.conv_norm_out(sample)
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
-
-        if skip_sample is not None:
-            sample += skip_sample
-
-        if self.config.time_embedding_type == "fourier":
-            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
-            sample = sample / timesteps
-
-        output = {"sample": sample}
-
-        return output
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 0fa6852bd1..ee593b4632 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -25,7 +25,7 @@
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging


-INDEX_FILE = "diffusion_model.pt"
+INDEX_FILE = "diffusion_pytorch_model.bin"

 logger = logging.get_logger(__name__)
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 5f9227c9cb..a1000ae2ef 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -28,7 +28,9 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)

     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
+    def __call__(
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -37,7 +39,7 @@ class DDIMPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         image = image.to(torch_device)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index a7309224ef..c947827f01 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -36,7 +36,7 @@ class DDPMPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         image = image.to(torch_device)
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 5b3c5dc8cb..e6b2090264 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -52,7 +52,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         text_embeddings = self.bert(text_input.input_ids.to(torch_device))

         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)
diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 0964225e8b..5445c44cd7 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -24,7 +24,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.vqvae.to(torch_device)

         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)
diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index 33ec1a3e98..88e557f967 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -38,7 +38,7 @@ class PNDMPipeline(DiffusionPipeline):

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         image = image.to(torch_device)
diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 5b3be8b66f..ba8fbd762c 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -14,7 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
     def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

-        img_size = self.model.config.image_size
+        img_size = self.model.config.sample_size
         shape = (1, 3, img_size, img_size)
         model = self.model.to(device)
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 6b8b17128d..5df13f3a5e 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import inspect
 import math
 import tempfile
@@ -23,7 +22,7 @@
 import numpy as np
 import torch
 import PIL
-from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
+from diffusers import UNet2DConditionModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -36,7 +35,7 @@ from diffusers import (
     PNDMScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
-    UNetUnconditionalModel,
+    UNet2DModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -271,7 +270,7 @@ class ModelTesterMixin:


 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel

     @property
     def dummy_input(self):
@@ -294,14 +293,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_channels": (32, 64),
-            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
-            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
-            "num_head_channels": None,
+            "block_out_channels": (32, 64),
+            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
+            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
+            "attention_head_dim": None,
             "out_channels": 3,
             "in_channels": 3,
-            "num_res_blocks": 2,
-            "image_size": 32,
+            "layers_per_block": 2,
+            "sample_size": 32,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -309,14 +308,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

 # TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
 # def test_output_pretrained(self):
-#     model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
+#     model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
 #     model.eval()
 #
 #     torch.manual_seed(0)
 #     if torch.cuda.is_available():
 #         torch.cuda.manual_seed_all(0)
 #
-#     noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+#     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
 #     time_step = torch.tensor([10])
 #
 #     with torch.no_grad():
@@ -330,7 +329,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):


 class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel

     @property
     def dummy_input(self):
@@ -353,23 +352,23 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "image_size": 32,
+            "sample_size": 32,
             "in_channels": 4,
             "out_channels": 4,
-            "num_res_blocks": 2,
-            "block_channels": (32, 64),
-            "num_head_channels": 32,
-            "conv_resample": True,
-            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
-            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "layers_per_block": 2,
+            "block_out_channels": (32, 64),
+            "attention_head_dim": 32,
+            "down_block_types": ("DownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "UpBlock2D"),
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True
+        model, loading_info = UNet2DModel.from_pretrained(
+            "/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
         )
+
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -379,14 +378,14 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
         model.eval()

         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         time_step = torch.tensor([10] * noise.shape[0])

         with torch.no_grad():
@@ -409,7 +408,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
 #     if torch.cuda.is_available():
 #         torch.cuda.manual_seed_all(0)
 #
-#     noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+#     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
 #     context = torch.ones((1, 16, 64), dtype=torch.float32)
 #     time_step = torch.tensor([10] * noise.shape[0])
 #
@@ -426,13 +425,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):


 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel

     @property
-    def dummy_input(self):
+    def dummy_input(self, sizes=(32, 32)):
         batch_size = 4
         num_channels = 3
-        sizes = (32, 32)

         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor(batch_size * [10]).to(torch_device)
@@ -449,44 +447,47 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_channels": [32, 64, 64, 64],
+            "block_out_channels": [32, 64, 64, 64],
             "in_channels": 3,
-            "num_res_blocks": 1,
+            "layers_per_block": 1,
             "out_channels": 3,
             "time_embedding_type": "fourier",
-            "resnet_eps": 1e-6,
+            "norm_eps": 1e-6,
             "mid_block_scale_factor": math.sqrt(2.0),
-            "resnet_num_groups": None,
-            "down_blocks": [
-                "UNetResSkipDownBlock2D",
-                "UNetResAttnSkipDownBlock2D",
-                "UNetResSkipDownBlock2D",
-                "UNetResSkipDownBlock2D",
+            "norm_num_groups": None,
+            "down_block_types": [
+                "SkipDownBlock2D",
+                "AttnSkipDownBlock2D",
+                "SkipDownBlock2D",
+                "SkipDownBlock2D",
             ],
-            "up_blocks": [
-                "UNetResSkipUpBlock2D",
-                "UNetResSkipUpBlock2D",
-                "UNetResAttnSkipUpBlock2D",
-                "UNetResSkipUpBlock2D",
+            "up_block_types": [
+                "SkipUpBlock2D",
+                "SkipUpBlock2D",
+                "AttnSkipUpBlock2D",
+                "SkipUpBlock2D",
             ],
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained(
-            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
+        model, loading_info = UNet2DModel.from_pretrained(
+            "/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
         )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)

         model.to(torch_device)
-        image = model(**self.dummy_input)
+        inputs = self.dummy_input
+        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
+        inputs["sample"] = noise
+        image = model(**inputs)

         assert image is not None, "Make sure output is not None"

     def test_output_pretrained_ve_mid(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
         model.to(torch_device)

         torch.manual_seed(0)
@@ -511,7 +512,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

     def test_output_pretrained_ve_large(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
         model.to(torch_device)

         torch.manual_seed(0)
@@ -540,10 +541,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = VQModel

     @property
-    def dummy_input(self):
+    def dummy_input(self, sizes=(32, 32)):
         batch_size = 4
         num_channels = 3
-        sizes = (32, 32)

         image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

@@ -570,7 +570,6 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
             "embed_dim": 3,
             "sane_index_shape": False,
             "ch_mult": (1,),
-            "dropout": 0.0,
             "double_z": False,
         }
         inputs_dict = self.dummy_input
@@ -583,7 +582,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         pass

     def test_from_pretrained_hub(self):
-        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
+        model, loading_info = VQModel.from_pretrained(
+            "/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -593,7 +594,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = VQModel.from_pretrained("fusing/vqgan-dummy")
+        model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
         model.eval()

         torch.manual_seed(0)
@@ -654,7 +655,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         pass

     def test_from_pretrained_hub(self):
-        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
+        model, loading_info = AutoencoderKL.from_pretrained(
+            "/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -664,7 +667,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"

     def test_output_pretrained(self):
-        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+        model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
         model.eval()

         torch.manual_seed(0)
@@ -685,14 +688,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):

 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
-        model = UNetUnconditionalModel(
-            block_channels=(32, 64),
-            num_res_blocks=2,
-            image_size=32,
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
             in_channels=3,
             out_channels=3,
-            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
-            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
         )
         schedular = DDPMScheduler(num_train_timesteps=10)
@@ -712,7 +715,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_from_pretrained_hub(self):
-        model_path = "google/ddpm-cifar10-32"
+        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

         ddpm = DDPMPipeline.from_pretrained(model_path)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -730,7 +733,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_output_format(self):
-        model_path = "google/ddpm-cifar10-32"
+        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

         pipe = DDIMPipeline.from_pretrained(model_path)

@@ -751,9 +754,9 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddpm_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)
         scheduler = scheduler.set_format("pt")
@@ -770,9 +773,9 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddim_lsun(self):
-        model_id = "google/ddpm-ema-bedroom-256"
+        model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"

-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler.from_config(model_id)

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -788,9 +791,9 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ddim_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler(tensor_format="pt")

         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -806,9 +809,9 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_pndm_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"

-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = PNDMScheduler(tensor_format="pt")

         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
@@ -823,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -839,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -853,13 +856,13 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")

         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
+        scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")

         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     def test_ldm_uncond(self):
-        ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+        ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")

         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]