diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 19c6162448..d3f8413415 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -415,33 +415,23 @@ def main():
     )
 
     def collate_fn(examples):
-        def _collate(input_ids, pixel_values):
-            pixel_values = torch.stack([pixel_value for pixel_value in pixel_values])
-            pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
 
-            input_ids = [input_id for input_id in input_ids]
-            input_ids = tokenizer.pad(
-                {"input_ids": input_ids},
-                padding=True,
-                return_tensors="pt",
-            ).input_ids
-            return input_ids, pixel_values
+        # concat class and instance examples for prior preservation
+        if args.with_prior_preservation:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
 
-        instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
-        instance_images = [example["instance_images"] for example in examples]
-        instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
-            "instance_images": instance_images,
-            "instance_prompt_ids": instance_prompt_ids,
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
         }
-
-        if args.with_prior_preservation:
-            class_prompt_ids = [example["class_prompt_ids"] for example in examples]
-            class_images = [example["class_images"] for example in examples]
-            class_prompt_ids, class_images = _collate(class_prompt_ids, class_images)
-            batch["class_images"] = class_images
-            batch["class_prompt_ids"] = class_prompt_ids
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -503,15 +493,8 @@ def main():
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                if args.with_prior_preservation:
-                    images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
-                    input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
-                else:
-                    images = batch["instance_images"]
-                    input_ids = batch["instance_prompt_ids"]
-
                 with torch.no_grad():
-                    latents = vae.encode(images).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
@@ -528,7 +511,7 @@ def main():
 
                 # Get the text embedding for conditioning
                 with torch.no_grad():
-                    encoder_hidden_states = text_encoder(input_ids)[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
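
For context, below is a minimal self-contained sketch of what the refactored collate_fn produces. The dummy `examples`, the 3x8x8 image tensors, and the `pad_input_ids` helper (a hypothetical stand-in for `tokenizer.pad({"input_ids": ...}, padding=True, return_tensors="pt").input_ids`) are illustrative assumptions, not code from this PR:

import torch

with_prior_preservation = True  # stands in for args.with_prior_preservation


def pad_input_ids(ids_list, pad_token_id=0):
    # Hypothetical stand-in for tokenizer.pad: right-pad every sequence
    # to the longest one in the batch and return a LongTensor.
    max_len = max(len(ids) for ids in ids_list)
    return torch.tensor([ids + [pad_token_id] * (max_len - len(ids)) for ids in ids_list])


def collate_fn(examples):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Appending class examples after instance examples lets prior
    # preservation run in a single forward pass per step.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
    return {"input_ids": pad_input_ids(input_ids), "pixel_values": pixel_values}


# Two dummy examples; in the real script the images are 3x512x512 after the train transforms.
examples = [
    {
        "instance_prompt_ids": [101, 7, 8, 102],
        "instance_images": torch.randn(3, 8, 8),
        "class_prompt_ids": [101, 9, 102],
        "class_images": torch.randn(3, 8, 8),
    }
    for _ in range(2)
]

batch = collate_fn(examples)
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 8, 8]): 2 instance + 2 class images
print(batch["input_ids"].shape)     # torch.Size([4, 4]): padded to the longest prompt

Because the batch is concatenated once in collate_fn, the training loop only reads batch["pixel_values"] and batch["input_ids"], which is why the old torch.cat branch inside the loop could be deleted. The prior-preservation loss can still separate the two halves of the prediction (e.g. with torch.chunk(noise_pred, 2, dim=0)) since instance examples always come first.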