Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix mutable proj_out weight in the Attention layer (#73)
* Catch unused params in DDP
* Fix proj_out, add test
@@ -4,7 +4,7 @@ import os
 import torch
 import torch.nn.functional as F
 
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedDataParallelKwargs
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
@@ -27,8 +27,14 @@ logger = get_logger(__name__)
 
 
 def main(args):
+    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        logging_dir=logging_dir,
+        kwargs_handlers=[ddp_unused_params],
+    )
 
     model = UNetModel(
         attn_resolutions=(16,),
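Not part of the diff — a minimal, runnable sketch of what the new kwargs handler does: `DistributedDataParallelKwargs(find_unused_parameters=True)` is forwarded by `Accelerator.prepare` to PyTorch's `DistributedDataParallel` wrapper, so parameters that receive no gradient in a step no longer make DDP error out. The toy model and optimizer below are placeholders, not the training script's objects.

import torch
from accelerate import Accelerator, DistributedDataParallelKwargs

# Tell DDP (via accelerate) to tolerate parameters that get no gradient in a step.
ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_unused_params])

model = torch.nn.Linear(8, 8)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)
# Under a multi-process launch, `model` is now wrapped in
# torch.nn.parallel.DistributedDataParallel(..., find_unused_parameters=True).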
@@ -70,7 +70,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
 
-        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
 
         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:
@@ -108,15 +108,15 @@ class AttentionBlock(nn.Module):
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data
 
-            self.proj_out = proj_out
+            self.proj = proj_out
         elif self.overwrite_linear:
             self.qkv.weight.data = torch.concat(
                 [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
             )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
 
-            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj_out.bias.data = self.NIN_3.b.data
+            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj.bias.data = self.NIN_3.b.data
 
             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data
@@ -149,7 +149,7 @@ class AttentionBlock(nn.Module):
         a = torch.einsum("bts,bcs->bct", weight, v)
         h = a.reshape(bs, -1, length)
 
-        h = self.proj_out(h)
+        h = self.proj(h)
         h = h.reshape(b, c, *spatial)
 
         result = x + h
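Not part of the diff — an illustrative, self-contained sketch of the output-projection path touched above (the names below are placeholders, not the library class): the attended features pass through a zero-initialized 1x1 Conv1d (the layer renamed from `proj_out` to `proj`) and are added back to the block input, so a freshly initialized attention block starts out as an identity mapping.

import torch
import torch.nn as nn

channels, length = 4, 16
proj = nn.Conv1d(channels, channels, 1)   # the renamed output projection
nn.init.zeros_(proj.weight)               # mirrors zero_module(): weight and bias start at zero
nn.init.zeros_(proj.bias)

x = torch.randn(2, channels, length)      # block input, spatial dims flattened to `length`
h = torch.randn(2, channels, length)      # stand-in for the attention output a.reshape(bs, -1, length)
result = x + proj(h)                      # h = self.proj(h); result = x + h
assert torch.allclose(result, x)          # zero-initialized projection leaves the input unchanged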
@@ -30,7 +30,7 @@ class EMAModel:
             min_value (float): The minimum EMA decay rate. Default: 0.
         """
 
-        self.averaged_model = copy.deepcopy(model)
+        self.averaged_model = copy.deepcopy(model).eval()
         self.averaged_model.requires_grad_(False)
 
         self.update_after_step = update_after_step
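Not part of the diff — a small illustrative example of why the averaged copy is switched to eval mode here: layers such as Dropout (and BatchNorm) behave differently in train and eval mode, and the EMA copy is only ever read for evaluation or sampling, never trained, so its outputs should be deterministic. The toy module below is an assumption, not the diffusion model.

import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))   # stand-in training model
averaged_model = copy.deepcopy(model).eval()                 # mirrors the change above
averaged_model.requires_grad_(False)

x = torch.randn(1, 8)
# In eval mode dropout is a no-op, so repeated calls give identical outputs.
assert torch.equal(averaged_model(x), averaged_model(x))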
@@ -52,6 +52,7 @@ from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
 from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.training_utils import EMAModel
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -197,6 +198,20 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
 
+    def test_ema_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        ema_model = EMAModel(model, device=torch_device)
+
+        output = model(**inputs_dict)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
+        ema_model.step(model)
+
 
 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetModel
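Not part of the diff — a minimal sketch of the training pattern the new test exercises, using only the `EMAModel` calls visible above (`EMAModel(model, device=...)`, `step(model)`, `averaged_model`); the toy model, optimizer, and data are assumptions:

import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)                     # placeholder for the diffusion model
optimizer = torch.optim.Adam(model.parameters())
ema_model = EMAModel(model, device="cpu")

for _ in range(3):                                # toy training loop
    pred = model(torch.randn(4, 8))
    loss = torch.nn.functional.mse_loss(pred, torch.randn(4, 8))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_model.step(model)                         # update the exponentially averaged weights
# ema_model.averaged_model holds the smoothed (eval-mode, no-grad) copy used for sampling.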