From b8b5daaee30ecbecd7b901020008ffead443665d Mon Sep 17 00:00:00 2001 From: Ambrosiussen Date: Mon, 22 May 2023 16:49:35 +0200 Subject: [PATCH] DataLoader respecting EXIF data in Training Images (#3465) * DataLoader will now bake in any transforms or image manipulations contained in the EXIF. Images may have rotations stored in EXIF. Training using such images will cause those transforms to be ignored while training and thus produce unexpected results. * Fixed the Dataloading EXIF issue in main DreamBooth training as well * Run make style (black & isort) --- examples/dreambooth/train_dreambooth.py | 23 ++++++++++++-------- examples/dreambooth/train_dreambooth_lora.py | 23 ++++++++++++-------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index efcfb39ab4..53d9c269f3 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -27,19 +27,13 @@ import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint +from torch.utils.data import Dataset + +import diffusers import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, model_info, upload_folder -from packaging import version -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - -import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -50,6 +44,13 @@ from diffusers import ( from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import create_repo, model_info, upload_folder +from packaging import version +from PIL import 
Image +from PIL.ImageOps import exif_transpose +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig if is_wandb_available(): @@ -607,6 +608,8 @@ class DreamBoothDataset(Dataset): def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) @@ -622,6 +625,8 @@ class DreamBoothDataset(Dataset): if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index bfbf3603e8..659b0d3e1d 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -27,19 +27,13 @@ import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint +from torch.utils.data import Dataset + +import diffusers import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder -from packaging import version -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - -import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, @@ -59,6 +53,13 @@ from diffusers.models.attention_processor import ( from diffusers.optimization import get_scheduler from diffusers.utils import 
TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -508,6 +509,8 @@ class DreamBoothDataset(Dataset): def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) @@ -523,6 +526,8 @@ class DreamBoothDataset(Dataset): if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image)