1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

concat batch in collate fn

This commit is contained in:
patil-suraj
2022-09-26 15:02:13 +02:00
parent ef01331146
commit c66cf4dc1a

View File

@@ -415,33 +415,23 @@ def main():
)
def collate_fn(examples):
def _collate(input_ids, pixel_values):
pixel_values = torch.stack([pixel_value for pixel_value in pixel_values])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
input_ids = [input_id for input_id in input_ids]
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding=True,
return_tensors="pt",
).input_ids
return input_ids, pixel_values
# concat class and instance examples for prior preservation
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
instance_images = [example["instance_images"] for example in examples]
instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
batch = {
"instance_images": instance_images,
"instance_prompt_ids": instance_prompt_ids,
"input_ids": input_ids,
"pixel_values": pixel_values,
}
if args.with_prior_preservation:
class_prompt_ids = [example["class_prompt_ids"] for example in examples]
class_images = [example["class_images"] for example in examples]
class_prompt_ids, class_images = _collate(class_prompt_ids, class_images)
batch["class_images"] = class_images
batch["class_prompt_ids"] = class_prompt_ids
return batch
train_dataloader = torch.utils.data.DataLoader(
@@ -503,15 +493,8 @@ def main():
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
if args.with_prior_preservation:
images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
else:
images = batch["instance_images"]
input_ids = batch["instance_prompt_ids"]
with torch.no_grad():
latents = vae.encode(images).latent_dist.sample()
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
@@ -528,7 +511,7 @@ def main():
# Get the text embedding for conditioning
with torch.no_grad():
encoder_hidden_states = text_encoder(input_ids)[0]
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample