mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
concat batch in collate fn
This commit is contained in:
@@ -415,33 +415,23 @@ def main():
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
def _collate(input_ids, pixel_values):
|
||||
pixel_values = torch.stack([pixel_value for pixel_value in pixel_values])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
|
||||
input_ids = [input_id for input_id in input_ids]
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
return input_ids, pixel_values
|
||||
# concat class and instance examples for prior preservation
|
||||
if args.with_prior_preservation:
|
||||
input_ids += [example["class_prompt_ids"] for example in examples]
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
|
||||
instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
instance_images = [example["instance_images"] for example in examples]
|
||||
instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
||||
|
||||
batch = {
|
||||
"instance_images": instance_images,
|
||||
"instance_prompt_ids": instance_prompt_ids,
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_prompt_ids = [example["class_prompt_ids"] for example in examples]
|
||||
class_images = [example["class_images"] for example in examples]
|
||||
class_prompt_ids, class_images = _collate(class_prompt_ids, class_images)
|
||||
batch["class_images"] = class_images
|
||||
batch["class_prompt_ids"] = class_prompt_ids
|
||||
return batch
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
@@ -503,15 +493,8 @@ def main():
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
if args.with_prior_preservation:
|
||||
images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
|
||||
input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
|
||||
else:
|
||||
images = batch["instance_images"]
|
||||
input_ids = batch["instance_prompt_ids"]
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(images).latent_dist.sample()
|
||||
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -528,7 +511,7 @@ def main():
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states = text_encoder(input_ids)[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
Reference in New Issue
Block a user