mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Utils] Adds store() and restore() methods to EMAModel (#2302)
* add store and restore() methods to EMAModel. * Update src/diffusers/training_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * make style with doc builder * remove explicit listing. * Apply suggestions from code review Co-authored-by: Will Berman <wlbberman@gmail.com> * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * chore: better variable naming. * better treatment of temp_stored_params Co-authored-by: patil-suraj <surajp815@gmail.com> * make style * remove temporary params from earth 🌎 * make fix-copies. --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Will Berman <wlbberman@gmail.com> Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
@@ -115,7 +115,7 @@ class EMAModel:
|
||||
deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
|
||||
self.to(device=kwargs["device"])
|
||||
|
||||
self.collected_params = None
|
||||
self.temp_stored_params = None
|
||||
|
||||
self.decay = decay
|
||||
self.min_decay = min_decay
|
||||
@@ -149,7 +149,6 @@ class EMAModel:
|
||||
model = self.model_cls.from_config(self.model_config)
|
||||
state_dict = self.state_dict()
|
||||
state_dict.pop("shadow_params", None)
|
||||
state_dict.pop("collected_params", None)
|
||||
|
||||
model.register_to_config(**state_dict)
|
||||
self.copy_to(model.parameters())
|
||||
@@ -248,9 +247,35 @@ class EMAModel:
|
||||
"inv_gamma": self.inv_gamma,
|
||||
"power": self.power,
|
||||
"shadow_params": self.shadow_params,
|
||||
"collected_params": self.collected_params,
|
||||
}
|
||||
|
||||
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
    r"""
    Save the current parameters for restoring later.

    Args:
        parameters (`Iterable[torch.nn.Parameter]`):
            The parameters to be temporarily stored.
    """
    # Detach and move the copies to CPU: the stored tensors carry no autograd
    # graph and do not occupy accelerator memory while EMA weights are active.
    self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
    r"""
    Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
    without affecting the original optimization process. Store the parameters before the `copy_to()` method.
    After validation (or model saving), use this to restore the former parameters.

    Args:
        parameters (`Iterable[torch.nn.Parameter]`):
            The parameters to be updated with the stored parameters.

    Raises:
        RuntimeError: If `store()` was not called before `restore()`.
    """
    if self.temp_stored_params is None:
        raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
    # Copy the stored values back into the live parameters in place.
    for c_param, param in zip(self.temp_stored_params, parameters):
        param.data.copy_(c_param.data)

    # Better memory-wise: drop the stored copies as soon as they are consumed.
    self.temp_stored_params = None
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
r"""
|
||||
Args:
|
||||
@@ -297,12 +322,3 @@ class EMAModel:
|
||||
raise ValueError("shadow_params must be a list")
|
||||
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
|
||||
raise ValueError("shadow_params must all be Tensors")
|
||||
|
||||
self.collected_params = state_dict.get("collected_params", None)
|
||||
if self.collected_params is not None:
|
||||
if not isinstance(self.collected_params, list):
|
||||
raise ValueError("collected_params must be a list")
|
||||
if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
|
||||
raise ValueError("collected_params must all be Tensors")
|
||||
if len(self.collected_params) != len(self.shadow_params):
|
||||
raise ValueError("collected_params and shadow_params must have the same length")
|
||||
|
||||
@@ -212,7 +212,7 @@ class StableDiffusionPipelineSafe(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionSAGPipeline(metaclass=DummyObject):
|
||||
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -227,7 +227,7 @@ class StableDiffusionSAGPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
|
||||
class StableDiffusionSAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user