mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* fix bug in micro-conditioning of class images
* fix bug in micro-conditioning of class images
* style
This commit is contained in:
@@ -939,6 +939,32 @@ class DreamBoothDataset(Dataset):
|
||||
self.class_data_root = Path(class_data_root)
|
||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||
self.class_images_path = list(self.class_data_root.iterdir())
|
||||
|
||||
self.original_sizes_class_imgs = []
|
||||
self.crop_top_lefts_class_imgs = []
|
||||
self.pixel_values_class_imgs = []
|
||||
self.class_images = [Image.open(path) for path in self.class_images_path]
|
||||
for image in self.class_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes_class_imgs.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts_class_imgs.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values_class_imgs.append(image)
|
||||
|
||||
if class_num is not None:
|
||||
self.num_class_images = min(len(self.class_images_path), class_num)
|
||||
else:
|
||||
@@ -961,12 +987,9 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
example["instance_images"] = self.pixel_values[index % self.num_instance_images]
|
||||
example["original_size"] = self.original_sizes[index % self.num_instance_images]
|
||||
example["crop_top_left"] = self.crop_top_lefts[index % self.num_instance_images]
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -983,13 +1006,10 @@ class DreamBoothDataset(Dataset):
|
||||
example["instance_prompt"] = self.instance_prompt
|
||||
|
||||
if self.class_data_root:
|
||||
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
||||
class_image = exif_transpose(class_image)
|
||||
|
||||
if not class_image.mode == "RGB":
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_prompt"] = self.class_prompt
|
||||
example["class_images"] = self.pixel_values_class_imgs[index % self.num_class_images]
|
||||
example["class_original_size"] = self.original_sizes_class_imgs[index % self.num_class_images]
|
||||
example["class_crop_top_left"] = self.crop_top_lefts_class_imgs[index % self.num_class_images]
|
||||
|
||||
return example
|
||||
|
||||
@@ -1005,6 +1025,8 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
if with_prior_preservation:
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
prompts += [example["class_prompt"] for example in examples]
|
||||
original_sizes += [example["class_original_size"] for example in examples]
|
||||
crop_top_lefts += [example["class_crop_top_left"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
Reference in New Issue
Block a user