diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index 0813cadc63..4cf553c7b1 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -130,7 +130,7 @@ def main(args):
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
             lr_scheduler.step()
-            ema_model.step(model, global_step)
+            ema_model.step(model)
             optimizer.zero_grad()
             progress_bar.update(1)
             progress_bar.set_postfix(
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index f81bf5cc03..fed46706b5 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -43,6 +43,7 @@ class EMAModel:
         self.averaged_model = self.averaged_model.to(device=device)
 
         self.decay = 0.0
+        self.optimization_step = 0
 
     def get_decay(self, optimization_step):
         """
@@ -57,11 +58,11 @@
         return max(self.min_value, min(value, self.max_value))
 
     @torch.no_grad()
-    def step(self, new_model, optimization_step):
+    def step(self, new_model):
         ema_state_dict = {}
         ema_params = self.averaged_model.state_dict()
 
-        self.decay = self.get_decay(optimization_step)
+        self.decay = self.get_decay(self.optimization_step)
 
         for key, param in new_model.named_parameters():
             if isinstance(param, dict):
@@ -85,3 +86,4 @@
             ema_state_dict[key] = param
 
         self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+        self.optimization_step += 1
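Not part of the patch above: a minimal, self-contained sketch of the pattern the change introduces. `SimpleEMA`, its `max_value` default, and the warm-up formula in `get_decay` are illustrative assumptions rather than the diffusers implementation; the point is only that the step counter now lives inside the EMA object, so the training loop calls `step(model)` without passing a global step.

```python
# Minimal sketch (not the diffusers EMAModel): an EMA helper that tracks its
# own optimization step internally, mirroring the API change in the diff.
import copy

import torch


class SimpleEMA:
    def __init__(self, model, max_value=0.9999):  # max_value is an assumed default
        self.averaged_model = copy.deepcopy(model).eval()
        self.averaged_model.requires_grad_(False)
        self.max_value = max_value
        self.decay = 0.0
        self.optimization_step = 0  # incremented internally, no longer supplied by the caller

    def get_decay(self, optimization_step):
        # Illustrative warm-up schedule: decay grows with the step count, capped at max_value.
        value = (1 + optimization_step) / (10 + optimization_step)
        return min(value, self.max_value)

    @torch.no_grad()
    def step(self, new_model):
        self.decay = self.get_decay(self.optimization_step)
        for ema_param, param in zip(self.averaged_model.parameters(), new_model.parameters()):
            # ema = decay * ema + (1 - decay) * current
            ema_param.mul_(self.decay).add_(param, alpha=1 - self.decay)
        self.optimization_step += 1


# Usage in a training loop: the step counter is internal state.
model = torch.nn.Linear(4, 4)
ema = SimpleEMA(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.step(model)  # no global_step argument needed
    optimizer.zero_grad()
```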