1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix mask discrepancies in train_dreambooth_inpaint (#1529)

The mask and instance image were being cropped in different ways when --center_crop was not set, causing the model to learn to ignore the mask in some cases. This PR fixes that and generates more consistent results.
This commit is contained in:
Adalberto
2022-12-05 13:26:36 -03:00
committed by GitHub
parent 634be6e53d
commit e289998932

View File

@@ -295,10 +295,15 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
self.image_transforms_resize_and_crop = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
]
)
self.image_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -312,6 +317,7 @@ class DreamBoothDataset(Dataset):
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
instance_image = self.image_transforms_resize_and_crop(instance_image)
example["PIL_images"] = instance_image
example["instance_images"] = self.image_transforms(instance_image)
@@ -327,6 +333,7 @@ class DreamBoothDataset(Dataset):
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
class_image = self.image_transforms_resize_and_crop(class_image)
example["class_images"] = self.image_transforms(class_image)
example["class_PIL_images"] = class_image
example["class_prompt_ids"] = self.tokenizer(
@@ -513,12 +520,6 @@ def main():
)
def collate_fn(examples):
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
]
)
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
@@ -535,9 +536,6 @@ def main():
pil_image = example["PIL_images"]
# generate a random mask
mask = random_mask(pil_image.size, 1, False)
# apply transforms
mask = image_transforms(mask)
pil_image = image_transforms(pil_image)
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
@@ -548,9 +546,6 @@ def main():
for pil_image in pior_pil:
# generate a random mask
mask = random_mask(pil_image.size, 1, False)
# apply transforms
mask = image_transforms(mask)
pil_image = image_transforms(pil_image)
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)