diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index 4cf553c7b1..ebe5eb9826 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -4,7 +4,7 @@ import os import torch import torch.nn.functional as F -from accelerate import Accelerator +from accelerate import Accelerator, DistributedDataParallelKwargs from accelerate.logging import get_logger from datasets import load_dataset from diffusers import DDIMPipeline, DDIMScheduler, UNetModel @@ -27,8 +27,14 @@ logger = get_logger(__name__) def main(args): + ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir) + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + kwargs_handlers=[ddp_unused_params], + ) model = UNetModel( attn_resolutions=(16,), diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7e8af27e23..5f86993cfa 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -70,7 +70,7 @@ class AttentionBlock(nn.Module): if encoder_channels is not None: self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1) - self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + self.proj = zero_module(nn.Conv1d(channels, channels, 1)) self.overwrite_qkv = overwrite_qkv if overwrite_qkv: @@ -108,15 +108,15 @@ class AttentionBlock(nn.Module): proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0] proj_out.bias.data = module.proj_out.bias.data - self.proj_out = proj_out + self.proj = proj_out elif self.overwrite_linear: self.qkv.weight.data = torch.concat( [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0 )[:, :, None] self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) - self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None] - self.proj_out.bias.data = self.NIN_3.b.data + self.proj.weight.data = self.NIN_3.W.data.T[:, :, None] + self.proj.bias.data = self.NIN_3.b.data self.norm.weight.data = self.GroupNorm_0.weight.data self.norm.bias.data = self.GroupNorm_0.bias.data @@ -149,7 +149,7 @@ class AttentionBlock(nn.Module): a = torch.einsum("bts,bcs->bct", weight, v) h = a.reshape(bs, -1, length) - h = self.proj_out(h) + h = self.proj(h) h = h.reshape(b, c, *spatial) result = x + h diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fed46706b5..45be9acc6a 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -30,7 +30,7 @@ class EMAModel: min_value (float): The minimum EMA decay rate. Default: 0. """ - self.averaged_model = copy.deepcopy(model) + self.averaged_model = copy.deepcopy(model).eval() self.averaged_model.requires_grad_(False) self.update_after_step = update_after_step diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 420aea2ac3..aee90dbd9d 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -52,6 +52,7 @@ from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.bddm.pipeline_bddm import DiffWave from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.training_utils import EMAModel torch.backends.cuda.matmul.allow_tf32 = False @@ -197,6 +198,20 @@ class ModelTesterMixin: loss = torch.nn.functional.mse_loss(output, noise) loss.backward() + def test_ema_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + ema_model = EMAModel(model, device=torch_device) + + output = model(**inputs_dict) + noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + ema_model.step(model) + class UnetModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNetModel