1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Utils] Adds store() and restore() methods to EMAModel (#2302)

* add store and restore() methods to EMAModel.

* Update src/diffusers/training_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style with doc builder

* remove explicit listing.

* Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* chore: better variable naming.

* better treatment of temp_stored_params

Co-authored-by: patil-suraj <surajp815@gmail.com>

* make style

* remove temporary params from earth 🌎

* make fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
Sayak Paul
2023-02-16 19:50:25 +05:30
committed by GitHub
parent b214bb25f8
commit 6eaebe8278
2 changed files with 30 additions and 14 deletions

View File

@@ -115,7 +115,7 @@ class EMAModel:
deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
self.to(device=kwargs["device"])
self.collected_params = None
self.temp_stored_params = None
self.decay = decay
self.min_decay = min_decay
@@ -149,7 +149,6 @@ class EMAModel:
model = self.model_cls.from_config(self.model_config)
state_dict = self.state_dict()
state_dict.pop("shadow_params", None)
state_dict.pop("collected_params", None)
model.register_to_config(**state_dict)
self.copy_to(model.parameters())
@@ -248,9 +247,35 @@ class EMAModel:
"inv_gamma": self.inv_gamma,
"power": self.power,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params,
}
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Save the current parameters for restoring later.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
validation (or model saving), use this to restore the former parameters.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
if self.temp_stored_params is None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
for c_param, param in zip(self.temp_stored_params, parameters):
param.data.copy_(c_param.data)
# Better memory-wise.
self.temp_stored_params = None
def load_state_dict(self, state_dict: dict) -> None:
r"""
Args:
@@ -297,12 +322,3 @@ class EMAModel:
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")
self.collected_params = state_dict.get("collected_params", None)
if self.collected_params is not None:
if not isinstance(self.collected_params, list):
raise ValueError("collected_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
raise ValueError("collected_params must all be Tensors")
if len(self.collected_params) != len(self.shadow_params):
raise ValueError("collected_params and shadow_params must have the same length")

View File

@@ -212,7 +212,7 @@ class StableDiffusionPipelineSafe(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionSAGPipeline(metaclass=DummyObject):
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -227,7 +227,7 @@ class StableDiffusionSAGPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
class StableDiffusionSAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):