mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix mutable proj_out weight in the Attention layer (#73)

* Catch unused params in DDP

* Fix proj_out, add test
Author: Anton Lozhkov
Date: 2022-07-04 12:36:37 +02:00 (committed by GitHub)
Parent: 3abf4bc439
Commit: d9316bf8bc
4 changed files with 29 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ import os
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedDataParallelKwargs
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
@@ -27,8 +27,14 @@ logger = get_logger(__name__)
 def main(args):
+    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        logging_dir=logging_dir,
+        kwargs_handlers=[ddp_unused_params],
+    )
     model = UNetModel(
         attn_resolutions=(16,),

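For context, a minimal self-contained sketch of the pattern this hunk introduces: forwarding find_unused_parameters=True to DistributedDataParallel through Accelerate's kwargs_handlers. The DummyModel, optimizer, and data below are illustrative placeholders, not code from the training script.

import torch
import torch.nn as nn
from accelerate import Accelerator, DistributedDataParallelKwargs


class DummyModel(nn.Module):
    # Toy model: `skip_branch` gets no gradient, which is the situation
    # find_unused_parameters=True exists for under DDP.
    def __init__(self):
        super().__init__()
        self.main_branch = nn.Linear(4, 4)
        self.skip_branch = nn.Linear(4, 4)

    def forward(self, x):
        return self.main_branch(x)


ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_unused_params])

model = DummyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(2, 4, device=accelerator.device)
loss = model(x).pow(2).mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()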
View File

@@ -70,7 +70,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:
@@ -108,15 +108,15 @@ class AttentionBlock(nn.Module):
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

-            self.proj_out = proj_out
+            self.proj = proj_out
         elif self.overwrite_linear:
             self.qkv.weight.data = torch.concat(
                 [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
             )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

-            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj_out.bias.data = self.NIN_3.b.data
+            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj.bias.data = self.NIN_3.b.data

             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data
@@ -149,7 +149,7 @@ class AttentionBlock(nn.Module):
         a = torch.einsum("bts,bcs->bct", weight, v)
         h = a.reshape(bs, -1, length)

-        h = self.proj_out(h)
+        h = self.proj(h)
         h = h.reshape(b, c, *spatial)

         result = x + h

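Reading across the three hunks above: the block's output projection is renamed from proj_out to proj, which appears to keep it distinct from the module.proj_out attribute consulted when porting weights from other checkpoints. Below is a condensed, simplified sketch of the output path only, with zero_module written out as the guided-diffusion-style helper it wraps; the class name and shapes are illustrative, not the real AttentionBlock.

import torch
import torch.nn as nn


def zero_module(module):
    # Zero every parameter so the freshly initialised block adds nothing
    # to the residual path until it is trained (guided-diffusion convention).
    for p in module.parameters():
        p.detach().zero_()
    return module


class TinyAttentionOutput(nn.Module):
    # Illustrative reduction of the output projection, not the full block.
    def __init__(self, channels):
        super().__init__()
        # Renamed from `proj_out` to `proj` by this commit.
        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

    def forward(self, x, h):
        b, c, *spatial = x.shape
        h = self.proj(h)               # 1x1 conv over the flattened spatial axis
        h = h.reshape(b, c, *spatial)
        return x + h                   # residual connection, identity at init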
View File

@@ -30,7 +30,7 @@ class EMAModel:
            min_value (float): The minimum EMA decay rate. Default: 0.
        """
-        self.averaged_model = copy.deepcopy(model)
+        self.averaged_model = copy.deepcopy(model).eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step

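For reference, a minimal sketch of the exponential-moving-average pattern this change touches: the averaged copy is held in eval mode with gradients disabled, and each step blends it toward the live model. The fixed decay and class name below are illustrative; diffusers' EMAModel appears to ramp the decay over time (the docstring above mentions a minimum decay rate), which this sketch omits.

import copy

import torch
import torch.nn as nn


class MinimalEMA:
    # Illustrative EMA holder, not diffusers' EMAModel.
    def __init__(self, model, decay=0.999):
        # .eval() fixes dropout/batch-norm behaviour for the averaged copy;
        # requires_grad_(False) keeps it out of autograd entirely.
        self.averaged_model = copy.deepcopy(model).eval()
        self.averaged_model.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def step(self, model):
        for avg_p, p in zip(self.averaged_model.parameters(), model.parameters()):
            # avg <- decay * avg + (1 - decay) * current
            avg_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)


# Tiny usage example with a toy model.
net = nn.Linear(4, 4)
ema = MinimalEMA(net)
net.weight.data.add_(1.0)   # pretend an optimizer step changed the weights
ema.step(net)               # averaged copy moves slightly toward the new weights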
View File

@@ -52,6 +52,7 @@ from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
 from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.training_utils import EMAModel

 torch.backends.cuda.matmul.allow_tf32 = False
@@ -197,6 +198,20 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

+    def test_ema_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        ema_model = EMAModel(model, device=torch_device)
+
+        output = model(**inputs_dict)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
+        ema_model.step(model)
+

 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetModel
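The new test runs exactly one forward/backward pass and then an EMA update. A hedged sketch of how that same call sequence would sit in a longer loop; the toy model, optimizer, and data are placeholders, and leaving EMAModel's extra keyword arguments (including device) at their defaults is an assumption, since the test passes device=torch_device explicitly.

import torch
import torch.nn as nn
from diffusers.training_utils import EMAModel

model = nn.Linear(4, 4)                       # placeholder for the real UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema_model = EMAModel(model)                   # schedule kwargs assumed to default sensibly

for _ in range(3):
    x = torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(x), torch.zeros(8, 4))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.step(model)                     # blend the averaged weights, as in the test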