mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Examples] Uniform notations in train_flux_lora (#10011)
[Examples] uniform naming notations since the in parameter `size` represents `args.resolution`, I thus replace the `args.resolution` inside DreamBoothData with `size`. And revise some notations such as `center_crop`. Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
instance_data_root,
|
||||
instance_prompt,
|
||||
class_prompt,
|
||||
@@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset):
|
||||
class_num=None,
|
||||
size=1024,
|
||||
repeats=1,
|
||||
center_crop=False,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
|
||||
self.instance_prompt = instance_prompt
|
||||
self.custom_instance_prompts = None
|
||||
@@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset):
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
y1 = max(0, int(round((image.height - self.size) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - self.size) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
@@ -1827,6 +1826,7 @@ def main(args):
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = DreamBoothDataset(
|
||||
args=args,
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||
@@ -1836,7 +1836,6 @@ def main(args):
|
||||
class_num=args.num_class_images,
|
||||
size=args.resolution,
|
||||
repeats=args.repeats,
|
||||
center_crop=args.center_crop,
|
||||
)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user