mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -1235,7 +1235,6 @@ def main(args):
|
||||
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
print("epoch:", epoch)
|
||||
transformer.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
@@ -1275,7 +1274,6 @@ def main(args):
|
||||
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
print(model_input.shape, timesteps, prompt_embeds.shape)
|
||||
|
||||
# Prepare rotary embeds
|
||||
image_rotary_emb = (
|
||||
@@ -1291,7 +1289,6 @@ def main(args):
|
||||
if transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
print(image_rotary_emb)
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
@@ -1306,26 +1303,26 @@ def main(args):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# # =====
|
||||
# weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
|
||||
# weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps
|
||||
# print(weights)
|
||||
# while len(weights.shape) < len(model_pred.shape):
|
||||
# weights = weights.unsqueeze(-1)
|
||||
# # =====
|
||||
# =====
|
||||
weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
|
||||
print(timesteps, weights)
|
||||
weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps
|
||||
while len(weights.shape) < len(model_pred.shape):
|
||||
weights = weights.unsqueeze(-1)
|
||||
# =====
|
||||
|
||||
# target = model_input
|
||||
target = model_input
|
||||
|
||||
if scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif scheduler.config.prediction_type == "v_prediction":
|
||||
target = scheduler.get_velocity(model_input, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
|
||||
# if scheduler.config.prediction_type == "epsilon":
|
||||
# target = noise
|
||||
# elif scheduler.config.prediction_type == "v_prediction":
|
||||
# target = scheduler.get_velocity(model_input, noise, timesteps)
|
||||
# else:
|
||||
# raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
|
||||
|
||||
# loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
# loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1)
|
||||
loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
|
||||
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
|
||||
# loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
Reference in New Issue
Block a user