1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add pipeline

This commit is contained in:
Patrick von Platen
2022-06-02 15:55:32 +02:00
parent e83c5363c6
commit 25feac9e65
5 changed files with 170 additions and 2 deletions

22
models/vision/ddpm/example.py Executable file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDiffusion
from modeling_ddpm import DDPM
import tempfile
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
# generate / sample
image = ddpm()
print(image)
# save and load with 0 extra code (handled by general `DiffusionPipeline` class)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
print("Model saved")
ddpm_new = DDPM.from_pretrained(tmpdirname)
print("Model loaded")
print(ddpm_new)

View File

@@ -0,0 +1,27 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline):
def __init__(self, unet, gaussian_sampler):
super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
def __call__(self, batch_size=1):
image = self.gaussian_sampler.sample(self.unet, batch_size=batch_size)
return image

View File

@@ -6,3 +6,6 @@ __version__ = "0.0.1"
from .models.unet import UNetModel
from .samplers.gaussian import GaussianDiffusion
from .pipeline_utils import DiffusionPipeline
from .modeling_utils import PreTrainedModel

View File

@@ -91,8 +91,8 @@ class Config:
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
def from_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -198,6 +198,14 @@ class Config:
f"Values will be initialized to default values."
)
return config_dict, unused_kwargs
@classmethod
def from_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
):
config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
model = cls(**config_dict)
if return_unused_kwargs:

View File

@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union
import importlib
from .configuration_utils import Config
# CHANGE to diffusers.utils
from transformers.utils import logging
INDEX_FILE = "diffusion_model.pt"
logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"GaussianDiffusion": ["save_config", "from_config"],
},
"transformers": {
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
},
}
class DiffusionPipeline(Config):
config_name = "model_index.json"
def __init__(self, **kwargs):
for name, module in kwargs.items():
# retrive library
library = module.__module__.split(".")[0]
# retrive class_name
class_name = module.__class__.__name__
# save model index config
self.register(**{name: (library, class_name)})
# set models
setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = self._dict_to_save
model_index_dict.pop("_class_name")
for name, (library_name, class_name) in self._dict_to_save.items():
importable_classes = LOADABLE_CLASSES[library_name]
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
save_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
save_method_name = importable_classes[class_name][0]
save_method = getattr(getattr(self, name), save_method_name)
save_method(os.path.join(save_directory, name))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path)
init_kwargs = {}
for name, (library_name, class_name) in config_dict.items():
importable_classes = LOADABLE_CLASSES[library_name]
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name)
loaded_sub_model = load_method(os.path.join(pretrained_model_name_or_path, name))
init_kwargs[name] = loaded_sub_model
model = cls(**init_kwargs)
return model