mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add pipeline
This commit is contained in:
22
models/vision/ddpm/example.py
Executable file
22
models/vision/ddpm/example.py
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
from diffusers import UNetModel, GaussianDiffusion
|
||||
from modeling_ddpm import DDPM
|
||||
import tempfile
|
||||
|
||||
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
|
||||
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
|
||||
|
||||
# compose Diffusion Pipeline
|
||||
ddpm = DDPM(unet, sampler)
|
||||
# generate / sample
|
||||
image = ddpm()
|
||||
print(image)
|
||||
|
||||
|
||||
# save and load with 0 extra code (handled by general `DiffusionPipeline` class)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
print("Model saved")
|
||||
ddpm_new = DDPM.from_pretrained(tmpdirname)
|
||||
print("Model loaded")
|
||||
print(ddpm_new)
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class DDPM(DiffusionPipeline):
|
||||
|
||||
def __init__(self, unet, gaussian_sampler):
|
||||
super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
|
||||
|
||||
def __call__(self, batch_size=1):
|
||||
image = self.gaussian_sampler.sample(self.unet, batch_size=batch_size)
|
||||
return image
|
||||
|
||||
@@ -6,3 +6,6 @@ __version__ = "0.0.1"
|
||||
|
||||
from .models.unet import UNetModel
|
||||
from .samplers.gaussian import GaussianDiffusion
|
||||
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
@@ -91,8 +91,8 @@ class Config:
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -198,6 +198,14 @@ class Config:
|
||||
f"Values will be initialized to default values."
|
||||
)
|
||||
|
||||
return config_dict, unused_kwargs
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
):
|
||||
config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
model = cls(**config_dict)
|
||||
|
||||
if return_unused_kwargs:
|
||||
|
||||
108
src/diffusers/pipeline_utils.py
Normal file
108
src/diffusers/pipeline_utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
import importlib
|
||||
|
||||
from .configuration_utils import Config
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_model.pt"
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"GaussianDiffusion": ["save_config", "from_config"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DiffusionPipeline(Config):
|
||||
|
||||
config_name = "model_index.json"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for name, module in kwargs.items():
|
||||
# retrive library
|
||||
library = module.__module__.split(".")[0]
|
||||
# retrive class_name
|
||||
class_name = module.__class__.__name__
|
||||
|
||||
# save model index config
|
||||
self.register(**{name: (library, class_name)})
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = self._dict_to_save
|
||||
model_index_dict.pop("_class_name")
|
||||
|
||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
|
||||
save_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
save_method_name = importable_classes[class_name][0]
|
||||
|
||||
save_method = getattr(getattr(self, name), save_method_name)
|
||||
save_method(os.path.join(save_directory, name))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path)
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
for name, (library_name, class_name) in config_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
|
||||
load_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
loaded_sub_model = load_method(os.path.join(pretrained_model_name_or_path, name))
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
model = cls(**init_kwargs)
|
||||
return model
|
||||
Reference in New Issue
Block a user