Mirror of https://github.com/huggingface/diffusers.git
fix mask discrepancies in train_dreambooth_inpaint (#1529)
The mask and the instance image were being cropped in different ways when --center_crop was not set, causing the model to learn to ignore the mask in some cases. This PR fixes that and generates more consistent results.
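A minimal sketch of the discrepancy (illustrative only, not part of the patch; it assumes nothing beyond torchvision and Pillow): when a transform containing RandomCrop is applied to the image and to the mask in two separate calls, each call samples its own crop window, so the mask no longer lines up with the image it is supposed to cover.

    from PIL import Image
    from torchvision import transforms

    image = Image.new("RGB", (768, 768))  # stand-in for a training image

    # Two independent calls draw two independently sampled crop windows.
    window_a = transforms.RandomCrop.get_params(image, (512, 512))
    window_b = transforms.RandomCrop.get_params(image, (512, 512))
    print(window_a, window_b)  # (top, left, height, width); the two rarely match

With --center_crop the crop is deterministic, so both windows coincide by construction, which is why the problem only shows up with random cropping.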
@@ -295,10 +295,15 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
+        self.image_transforms_resize_and_crop = transforms.Compose(
             [
                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+            ]
+        )
+
+        self.image_transforms = transforms.Compose(
+            [
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -312,6 +317,7 @@ class DreamBoothDataset(Dataset):
         instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
+        instance_image = self.image_transforms_resize_and_crop(instance_image)
 
         example["PIL_images"] = instance_image
         example["instance_images"] = self.image_transforms(instance_image)
@@ -327,6 +333,7 @@ class DreamBoothDataset(Dataset):
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
+            class_image = self.image_transforms_resize_and_crop(class_image)
             example["class_images"] = self.image_transforms(class_image)
             example["class_PIL_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
@@ -513,12 +520,6 @@ def main():
     )
 
     def collate_fn(examples):
-        image_transforms = transforms.Compose(
-            [
-                transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
-            ]
-        )
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -535,9 +536,6 @@ def main():
             pil_image = example["PIL_images"]
             # generate a random mask
             mask = random_mask(pil_image.size, 1, False)
-            # apply transforms
-            mask = image_transforms(mask)
-            pil_image = image_transforms(pil_image)
             # prepare mask and masked image
             mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
 
@@ -548,9 +546,6 @@ def main():
         for pil_image in pior_pil:
             # generate a random mask
             mask = random_mask(pil_image.size, 1, False)
-            # apply transforms
-            mask = image_transforms(mask)
-            pil_image = image_transforms(pil_image)
             # prepare mask and masked image
             mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
 
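Taken together, the hunks above move the resize-and-crop step into DreamBoothDataset so it runs exactly once per image, and collate_fn then draws the mask at the size of the already-cropped PIL image instead of re-transforming both. A rough sketch of the resulting flow, with hypothetical names (resize_and_crop and to_tensor stand in for the script's image_transforms_resize_and_crop and image_transforms):

    from PIL import Image
    from torchvision import transforms

    resize_and_crop = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
        ]
    )
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    pil_image = Image.new("RGB", (768, 768))  # stand-in for an instance image
    pil_image = resize_and_crop(pil_image)    # geometric augmentation happens once, here
    pixel_values = to_tensor(pil_image)       # tensor conversion only, no further cropping

    # In the training script, the random mask is then generated directly at
    # pil_image.size (via random_mask) and paired with the cropped image (via
    # prepare_mask_and_masked_image), so the mask, the masked image, and
    # pixel_values all refer to the same crop window.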