From 81125d8499b82da80e997c45c72ea54ebd8b8abb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 28 Mar 2023 09:03:21 +0200 Subject: [PATCH] Make dynamo wrapped modules work with save_pretrained (#2726) * Workaround for saving dynamo-wrapped models. * Accept suggestion from code review Co-authored-by: Patrick von Platen * Apply workaround when overriding pipeline components. * Ensure the correct config.json is saved to disk. Instead of the dynamo class. * Save correct module (not compiled one) * Add test * style * fix docstrings * Go back to using string comparisons. PyTorch CPU does not have _dynamo. * Simple test for save_pretrained of compiled models. * Helper function to test whether module is compiled. --------- Co-authored-by: Patrick von Platen --- src/diffusers/pipelines/pipeline_utils.py | 20 +++++++++- src/diffusers/utils/__init__.py | 3 +- src/diffusers/utils/testing_utils.py | 10 +++++ src/diffusers/utils/torch_utils.py | 9 ++++- tests/test_modeling_common.py | 16 ++++++++ tests/test_pipelines.py | 47 +++++++++++++++++++++-- 6 files changed, 99 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d3578745b8..a03c454e92 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -50,6 +50,7 @@ from ..utils import ( get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, + is_compiled_module, is_safetensors_available, is_torch_version, is_transformers_available, @@ -255,7 +256,14 @@ def maybe_raise_or_warn( if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + model_cls = sub_model.__class__ + if is_compiled_module(sub_model): + model_cls = sub_model._orig_mod.__class__ + + if not issubclass(model_cls, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" @@ -419,6 +427,10 @@ class DiffusionPipeline(ConfigMixin): if module is None: register_dict = {name: (None, None)} else: + # register the original module, not the dynamo compiled one + if is_compiled_module(module): + module = module._orig_mod + library = module.__module__.split(".")[0] # check if the module is a pipeline module @@ -484,6 +496,12 @@ class DiffusionPipeline(ConfigMixin): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + if is_compiled_module(sub_model): + sub_model = sub_model._orig_mod + model_cls = sub_model.__class__ + save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 14e975c487..615804c91a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -74,7 +74,7 @@ from .import_utils import ( from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION -from .torch_utils import randn_tensor +from .torch_utils import is_compiled_module, randn_tensor if is_torch_available(): @@ -86,6 +86,7 @@ if is_torch_available(): nightly, parse_flag_from_env, print_tensor_test, + require_torch_2, require_torch_gpu, skip_mps, slow, diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index bf8109ae5c..afea0540b7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -25,6 +25,7 @@ from .import_utils import ( is_onnx_available, is_opencv_available, is_torch_available, + is_torch_version, ) from .logging import get_logger @@ -165,6 +166,15 @@ def require_torch(test_case): return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) +def require_torch_2(test_case): + """ + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. + """ + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( + test_case + ) + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 113e64c16b..b9815cbcee 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -17,7 +17,7 @@ PyTorch utilities: Utilities related to PyTorch from typing import List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available +from .import_utils import is_torch_available, is_torch_version if is_torch_available(): @@ -68,3 +68,10 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents + + +def is_compiled_module(module): + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 932c147027..1c45ce11b8 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -27,6 +27,7 @@ from requests.exceptions import HTTPError from diffusers.models import UNet2DConditionModel from diffusers.training_utils import EMAModel from diffusers.utils import torch_device +from diffusers.utils.testing_utils import require_torch_gpu class ModelUtilsTest(unittest.TestCase): @@ -167,6 +168,21 @@ class ModelTesterMixin: max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + @require_torch_gpu + def test_from_save_pretrained_dynamo(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model = torch.compile(model) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + assert new_model.__class__ == self.model_class + def test_from_save_pretrained_dtype(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index cb5984885c..2616223c54 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -54,7 +54,16 @@ from diffusers import ( logging, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device +from diffusers.utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + floats_tensor, + is_flax_available, + nightly, + require_torch_2, + slow, + torch_device, +) from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu @@ -966,9 +975,41 @@ class PipelineSlowTests(unittest.TestCase): down_block_types=("DownBlock2D", "AttnDownBlock2D"), up_block_types=("AttnUpBlock2D", "UpBlock2D"), ) - schedular = DDPMScheduler(num_train_timesteps=10) + scheduler = DDPMScheduler(num_train_timesteps=10) - ddpm = DDPMPipeline(model, schedular) + ddpm = DDPMPipeline(model, scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + new_ddpm.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + generator = torch.Generator(device=torch_device).manual_seed(0) + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + + @require_torch_2 + def test_from_save_pretrained_dynamo(self): + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + model = torch.compile(model) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None)