1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Make dynamo wrapped modules work with save_pretrained (#2726)

* Workaround for saving dynamo-wrapped models.

* Accept suggestion from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply workaround when overriding pipeline components.

* Ensure the correct config.json is saved to disk.

Instead of the dynamo class.

* Save correct module (not compiled one)

* Add test

* style

* fix docstrings

* Go back to using string comparisons.

PyTorch CPU does not have _dynamo.

* Simple test for save_pretrained of compiled models.

* Helper function to test whether module is compiled.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca
2023-03-28 09:03:21 +02:00
committed by GitHub
parent d4f846fa74
commit 81125d8499
6 changed files with 99 additions and 6 deletions

View File

@@ -50,6 +50,7 @@ from ..utils import (
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
is_safetensors_available,
is_torch_version,
is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if is_compiled_module(sub_model):
model_cls = sub_model._orig_mod.__class__
if not issubclass(model_cls, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
@@ -419,6 +427,10 @@ class DiffusionPipeline(ConfigMixin):
if module is None:
register_dict = {name: (None, None)}
else:
# register the original module, not the dynamo compiled one
if is_compiled_module(module):
module = module._orig_mod
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
@@ -484,6 +496,12 @@ class DiffusionPipeline(ConfigMixin):
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
if is_compiled_module(sub_model):
sub_model = sub_model._orig_mod
model_cls = sub_model.__class__
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():

View File

@@ -74,7 +74,7 @@ from .import_utils import (
from .logging import get_logger
from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION
from .torch_utils import randn_tensor
from .torch_utils import is_compiled_module, randn_tensor
if is_torch_available():
@@ -86,6 +86,7 @@ if is_torch_available():
nightly,
parse_flag_from_env,
print_tensor_test,
require_torch_2,
require_torch_gpu,
skip_mps,
slow,

View File

@@ -25,6 +25,7 @@ from .import_utils import (
is_onnx_available,
is_opencv_available,
is_torch_available,
is_torch_version,
)
from .logging import get_logger
@@ -165,6 +166,15 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_torch_2(test_case):
"""
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
"""
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
test_case
)
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(

View File

@@ -17,7 +17,7 @@ PyTorch utilities: Utilities related to PyTorch
from typing import List, Optional, Tuple, Union
from . import logging
from .import_utils import is_torch_available
from .import_utils import is_torch_available, is_torch_version
if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents
def is_compiled_module(module):
"""Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)

View File

@@ -27,6 +27,7 @@ from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import require_torch_gpu
class ModelUtilsTest(unittest.TestCase):
@@ -167,6 +168,21 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@require_torch_gpu
def test_from_save_pretrained_dynamo(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model = torch.compile(model)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
assert new_model.__class__ == self.model_class
def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -54,7 +54,16 @@ from diffusers import (
logging,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
from diffusers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
floats_tensor,
is_flax_available,
nightly,
require_torch_2,
slow,
torch_device,
)
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
@@ -966,9 +975,41 @@ class PipelineSlowTests(unittest.TestCase):
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
schedular = DDPMScheduler(num_train_timesteps=10)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, schedular)
ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@require_torch_2
def test_from_save_pretrained_dynamo(self):
# 1. Load models
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model = torch.compile(model)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)