mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Make dynamo wrapped modules work with save_pretrained (#2726)
* Workaround for saving dynamo-wrapped models. * Accept suggestion from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Apply workaround when overriding pipeline components. * Ensure the correct config.json is saved to disk. Instead of the dynamo class. * Save correct module (not compiled one) * Add test * style * fix docstrings * Go back to using string comparisons. PyTorch CPU does not have _dynamo. * Simple test for save_pretrained of compiled models. * Helper function to test whether module is compiled. --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -50,6 +50,7 @@ from ..utils import (
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
sub_model = passed_class_obj[name]
|
||||
model_cls = sub_model.__class__
|
||||
if is_compiled_module(sub_model):
|
||||
model_cls = sub_model._orig_mod.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||
f" {expected_class_obj}"
|
||||
@@ -419,6 +427,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if module is None:
|
||||
register_dict = {name: (None, None)}
|
||||
else:
|
||||
# register the original module, not the dynamo compiled one
|
||||
if is_compiled_module(module):
|
||||
module = module._orig_mod
|
||||
|
||||
library = module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
@@ -484,6 +496,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
if is_compiled_module(sub_model):
|
||||
sub_model = sub_model._orig_mod
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
# search for the model's base class in LOADABLE_CLASSES
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
|
||||
@@ -74,7 +74,7 @@ from .import_utils import (
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .pil_utils import PIL_INTERPOLATION
|
||||
from .torch_utils import randn_tensor
|
||||
from .torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -86,6 +86,7 @@ if is_torch_available():
|
||||
nightly,
|
||||
parse_flag_from_env,
|
||||
print_tensor_test,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
|
||||
@@ -25,6 +25,7 @@ from .import_utils import (
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
)
|
||||
from .logging import get_logger
|
||||
|
||||
@@ -165,6 +166,15 @@ def require_torch(test_case):
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_2(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
|
||||
@@ -17,7 +17,7 @@ PyTorch utilities: Utilities related to PyTorch
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available
|
||||
from .import_utils import is_torch_available, is_torch_version
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -68,3 +68,10 @@ def randn_tensor(
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def is_compiled_module(module):
|
||||
"""Check whether the module was compiled with torch.compile()"""
|
||||
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
|
||||
return False
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
@@ -27,6 +27,7 @@ from requests.exceptions import HTTPError
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
@@ -167,6 +168,21 @@ class ModelTesterMixin:
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
@require_torch_gpu
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model = torch.compile(model)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
assert new_model.__class__ == self.model_class
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -54,7 +54,16 @@ from diffusers import (
|
||||
logging,
|
||||
)
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
floats_tensor,
|
||||
is_flax_available,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
|
||||
|
||||
|
||||
@@ -966,9 +975,41 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
schedular = DDPMScheduler(num_train_timesteps=10)
|
||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
|
||||
ddpm = DDPMPipeline(model, schedular)
|
||||
ddpm = DDPMPipeline(model, scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
|
||||
new_ddpm.to(torch_device)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@require_torch_2
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
# 1. Load models
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
model = torch.compile(model)
|
||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
|
||||
ddpm = DDPMPipeline(model, scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user