From 6eaebe82787f59a5245e5741add0a893bdf352f5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Feb 2023 19:50:25 +0530 Subject: [PATCH] [Utils] Adds `store()` and `restore()` methods to EMAModel (#2302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add store and restore() methods to EMAModel. * Update src/diffusers/training_utils.py Co-authored-by: Patrick von Platen * make style with doc builder * remove explicit listing. * Apply suggestions from code review Co-authored-by: Will Berman * Apply suggestions from code review Co-authored-by: Patrick von Platen * chore: better variable naming. * better treatment of temp_stored_params Co-authored-by: patil-suraj * make style * remove temporary params from earth 🌎 * make fix-copies. --------- Co-authored-by: Patrick von Platen Co-authored-by: Will Berman Co-authored-by: patil-suraj --- src/diffusers/training_utils.py | 40 +++++++++++++------ .../dummy_torch_and_transformers_objects.py | 4 +- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c77ea03adf..67a8e48d38 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -115,7 +115,7 @@ class EMAModel: deprecate("device", "1.0.0", deprecation_message, standard_warn=False) self.to(device=kwargs["device"]) - self.collected_params = None + self.temp_stored_params = None self.decay = decay self.min_decay = min_decay @@ -149,7 +149,6 @@ class EMAModel: model = self.model_cls.from_config(self.model_config) state_dict = self.state_dict() state_dict.pop("shadow_params", None) - state_dict.pop("collected_params", None) model.register_to_config(**state_dict) self.copy_to(model.parameters()) @@ -248,9 +247,35 @@ class EMAModel: "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, - "collected_params": self.collected_params, } + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + def load_state_dict(self, state_dict: dict) -> None: r""" Args: @@ -297,12 +322,3 @@ class EMAModel: raise ValueError("shadow_params must be a list") if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict.get("collected_params", None) - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a1394292d7..6b8ddd2a0e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -212,7 +212,7 @@ class StableDiffusionPipelineSafe(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionSAGPipeline(metaclass=DummyObject): +class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -227,7 +227,7 @@ class StableDiffusionSAGPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): +class StableDiffusionSAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs):