1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2024-09-01 02:35:19 +02:00
parent 588c6ee602
commit 74e6f90097

View File

@@ -1235,7 +1235,6 @@ def main(args):
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
for epoch in range(first_epoch, args.num_train_epochs):
print("epoch:", epoch)
transformer.train()
if args.train_text_encoder:
text_encoder.train()
@@ -1275,7 +1274,6 @@ def main(args):
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
)
timesteps = timesteps.long()
print(model_input.shape, timesteps, prompt_embeds.shape)
# Prepare rotary embeds
image_rotary_emb = (
@@ -1291,7 +1289,6 @@ def main(args):
if transformer.config.use_rotary_positional_embeddings
else None
)
print(image_rotary_emb)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -1306,26 +1303,26 @@ def main(args):
return_dict=False,
)[0]
# # =====
# weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
# weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps
# print(weights)
# while len(weights.shape) < len(model_pred.shape):
# weights = weights.unsqueeze(-1)
# # =====
# =====
weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
print(timesteps, weights)
weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps
while len(weights.shape) < len(model_pred.shape):
weights = weights.unsqueeze(-1)
# =====
# target = model_input
target = model_input
if scheduler.config.prediction_type == "epsilon":
target = noise
elif scheduler.config.prediction_type == "v_prediction":
target = scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
# if scheduler.config.prediction_type == "epsilon":
# target = noise
# elif scheduler.config.prediction_type == "v_prediction":
# target = scheduler.get_velocity(model_input, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
# loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1)
loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
# loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
accelerator.backward(loss)
if accelerator.sync_gradients: