Train a diffusion model

Unconditional image generation is a popular application of diffusion models: the model generates images that resemble those in the dataset it was trained on. Typically, the best results come from fine-tuning a pretrained model on a specific dataset. You can find many such checkpoints on the Hub, but if none of them suits you, you can always train your own!

This tutorial teaches you how to train a [UNet2DModel] on a subset of the Smithsonian Butterflies dataset to generate your own 🦋 butterflies 🦋.

💡 This training tutorial is based on the Training with 🧨 Diffusers notebook. For more detail on how diffusion models work, check out the notebook!

Before you begin, make sure 🤗 Datasets is installed to load and preprocess the image dataset, and 🤗 Accelerate is installed to simplify training on multiple GPUs. The command below also installs TensorBoard to visualize training metrics. (You can also use Weights & Biases to track your training.)

!pip install diffusers[training]

We encourage you to share your model with the community. To do that, you need to log in to your Hugging Face account (create one first if you don't already have one). You can log in from a notebook and enter your token when prompted.

>>> from huggingface_hub import notebook_login

>>> notebook_login()

Or log in from the terminal:

huggingface-cli login

Because model checkpoints are quite large, install Git-LFS to version these large files:

!sudo apt -qq install git-lfs
!git config --global credential.helper store

Training configuration

For convenience, create a TrainingConfig class that holds the training hyperparameters (feel free to adjust them):

>>> from dataclasses import dataclass


>>> @dataclass
... class TrainingConfig:
...     image_size = 128  # the generated image resolution
...     train_batch_size = 16
...     eval_batch_size = 16  # how many images to sample during evaluation
...     num_epochs = 50
...     gradient_accumulation_steps = 1
...     learning_rate = 1e-4
...     lr_warmup_steps = 500
...     save_image_epochs = 10
...     save_model_epochs = 30
...     mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
...     output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

...     push_to_hub = True  # whether to upload the saved model to the HF Hub
...     hub_model_id = None  # optional Hub repo id; None falls back to the output_dir name
...     hub_private_repo = False
...     overwrite_output_dir = True  # whether to overwrite the old model when re-running the notebook
...     seed = 0


>>> config = TrainingConfig()
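
The attributes above carry no type annotations, so `@dataclass` treats them as plain class attributes rather than fields. They can be changed on the instance after creation, but they cannot be passed to `TrainingConfig(...)` as keyword arguments. As a minimal sketch of an alternative (the class name `AnnotatedTrainingConfig` and the reduced set of fields are hypothetical, not part of the tutorial), annotating the attributes makes them overridable at construction time:

```python
from dataclasses import dataclass, replace


@dataclass
class AnnotatedTrainingConfig:  # hypothetical name, not part of the tutorial
    # Type annotations turn each attribute into a dataclass field,
    # so it can be passed to the constructor as a keyword argument.
    image_size: int = 128
    train_batch_size: int = 16
    num_epochs: int = 50
    learning_rate: float = 1e-4
    output_dir: str = "ddpm-butterflies-128"


# Override a few defaults for a quick smoke-test run.
debug_config = AnnotatedTrainingConfig(num_epochs=1, train_batch_size=4)

# `replace` copies a config while changing selected fields.
small_config = replace(debug_config, image_size=64)

print(debug_config)
print(small_config.image_size, small_config.num_epochs)
```

Either style works with the rest of the tutorial, since the training code only reads attributes from the config object.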

Load the dataset

You can easily load the Smithsonian Butterflies dataset with the 🤗 Datasets library:

>>> from datasets import load_dataset

>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
>>> dataset = load_dataset(config.dataset_name, split="train")

💡 You can find additional datasets from the HugGan Community Event, or use your own dataset by creating a local ImageFolder. Set config.dataset_name to the repository id of the dataset if it comes from the HugGan Community Event, or to imagefolder if you're using your own images.

🤗 Datasets uses the [~datasets.Image] feature to automatically decode the image data and load it as a PIL.Image, which you can visualize:

>>> import matplotlib.pyplot as plt

>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
>>> for i, image in enumerate(dataset[:4]["image"]):
...     axs[i].imshow(image)
...     axs[i].set_axis_off()
>>> fig.show()

The images are all different sizes, so you need to preprocess them first:

  • Resize changes the image size to the one defined in config.image_size.
  • RandomHorizontalFlip augments the dataset by randomly mirroring the images.
  • Normalize rescales the pixel values into the [-1, 1] range that the model expects.

>>> from torchvision import transforms

>>> preprocess = transforms.Compose(
...     [
...         transforms.Resize((config.image_size, config.image_size)),
...         transforms.RandomHorizontalFlip(),
...         transforms.ToTensor(),
...         transforms.Normalize([0.5], [0.5]),
...     ]
... )
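
To see why this pipeline ends up in the [-1, 1] range, the arithmetic can be reproduced without torchvision. The sketch below only mimics the per-pixel math for one 8-bit value (the helper names are hypothetical): ToTensor scales [0, 255] to [0, 1], and Normalize with mean 0.5 and standard deviation 0.5 then computes `(value - 0.5) / 0.5`.

```python
def to_unit_range(pixel: int) -> float:
    """Mimic what ToTensor does to an 8-bit pixel: scale [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic Normalize([0.5], [0.5]) for a single value: (value - mean) / std."""
    return (value - mean) / std


# Black, mid-gray, and white pixels map to roughly -1, 0, and 1.
for pixel in (0, 128, 255):
    print(pixel, "->", round(normalize(to_unit_range(pixel)), 4))
```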

Use 🤗 Datasets' [~datasets.Dataset.set_transform] method to apply the preprocess function on the fly during training:

>>> def transform(examples):
...     images = [preprocess(image.convert("RGB")) for image in examples["image"]]
...     return {"images": images}


>>> dataset.set_transform(transform)

Feel free to visualize the images again to confirm that they have been resized. Now you're ready to wrap the dataset in a DataLoader for training!

>>> import torch

>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

Create a UNet2DModel

Models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [UNet2DModel]:

>>> from diffusers import UNet2DModel

>>> model = UNet2DModel(
...     sample_size=config.image_size,  # the target image resolution
...     in_channels=3,  # the number of input channels, 3 for RGB images
...     out_channels=3,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "DownBlock2D",
...         "DownBlock2D",
...         "DownBlock2D",
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "DownBlock2D",
...     ),
...     up_block_types=(
...         "UpBlock2D",  # a regular ResNet upsampling block
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...     ),
... )

It is a good idea to quickly check that the shape of a sample image matches the shape of the model output:

>>> sample_image = dataset[0]["images"].unsqueeze(0)
>>> print("Input shape:", sample_image.shape)
Input shape: torch.Size([1, 3, 128, 128])

>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
Output shape: torch.Size([1, 3, 128, 128])
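
The output shape matches the input because the up blocks undo the downsampling done by the down blocks. As a rough sketch of the resolutions involved, assuming every down block except the last one halves the spatial size (an assumption about the architecture, not something this tutorial states), the six blocks configured above would see the following feature-map sizes:

```python
def block_input_resolutions(sample_size, num_blocks):
    """Spatial size seen by each down block, assuming all but the last one halve it."""
    resolutions, size = [], sample_size
    for index in range(num_blocks):
        resolutions.append(size)
        if index < num_blocks - 1:  # assumption: the final block does not downsample
            size //= 2
    return resolutions


channels = (128, 128, 256, 256, 512, 512)  # block_out_channels from the model above
for size, width in zip(block_input_resolutions(128, len(channels)), channels):
    print(f"input {size}x{size} -> {width} output channels")
```

Under that assumption, the image size should be divisible by 32 (five halvings) for the shapes to line up cleanly.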

Great! Next, you need a scheduler to add some noise to the image.

Create a scheduler

The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from noise. During training, the scheduler takes a model output, or a sample, from a specific point in the diffusion process and applies noise to the image according to a noise schedule and an update rule.

Take a look at the DDPMScheduler and use its add_noise method to add some random noise to the sample_image from before:

>>> import torch
>>> from PIL import Image
>>> from diffusers import DDPMScheduler

>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> noise = torch.randn(sample_image.shape)
>>> timesteps = torch.LongTensor([50])
>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

>>> # Clamp to [0, 255] before casting, since noisy values can fall outside [-1, 1]
>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).clamp(0, 255).type(torch.uint8).numpy()[0])
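
To make the noising step more concrete, here is a minimal pure-Python sketch of the closed-form forward process, x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, where ᾱ_t is the cumulative product of (1 − β_t). It assumes a linear β schedule from 1e-4 to 0.02 over 1000 steps; those values are meant to mirror the scheduler's defaults, so check `noise_scheduler.config` if you rely on them. The function names are hypothetical and this is not the library's implementation.

```python
import math
import random


def linear_alpha_bars(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for t in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (num_train_timesteps - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0, noise, alpha_bar):
    """x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise, element-wise."""
    signal_scale = math.sqrt(alpha_bar)
    noise_scale = math.sqrt(1.0 - alpha_bar)
    return [signal_scale * x + noise_scale * n for x, n in zip(x0, noise)]


alpha_bars = linear_alpha_bars()
random.seed(0)
x0 = [-1.0, 0.0, 1.0]  # toy "pixels" already normalized to [-1, 1]
noise = [random.gauss(0.0, 1.0) for _ in x0]

# Early timesteps keep most of the signal; late timesteps are almost pure noise.
for t in (50, 500, 999):
    print(t, round(alpha_bars[t], 4), add_noise(x0, noise, alpha_bars[t]))
```

At timestep 50, as used above, most of the original image is still visible, which is why the noisy butterfly remains recognizable.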

The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated as follows:

>>> import torch.nn.functional as F

>>> noise_pred = model(noisy_image, timesteps).sample
>>> loss = F.mse_loss(noise_pred, noise)
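
For intuition, the mean squared error used here is just the average of the squared differences between the predicted and the true noise. A small sketch with plain lists (the numbers are made up for illustration):

```python
def mse_loss(prediction, target):
    """Mean of squared element-wise differences, the default reduction of F.mse_loss."""
    assert len(prediction) == len(target)
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


true_noise = [0.5, -1.0, 0.25, 2.0]
perfect_prediction = list(true_noise)
zero_prediction = [0.0] * len(true_noise)

print(mse_loss(perfect_prediction, true_noise))  # a perfect prediction has zero loss
print(mse_loss(zero_prediction, true_noise))  # predicting "no noise" is penalized
```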

Train the model

By now, you have most of the pieces needed to start training the model. All that's left is putting everything together.

First, you need an optimizer and a learning rate scheduler:

>>> from diffusers.optimization import get_cosine_schedule_with_warmup

>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
>>> lr_scheduler = get_cosine_schedule_with_warmup(
...     optimizer=optimizer,
...     num_warmup_steps=config.lr_warmup_steps,
...     num_training_steps=(len(train_dataloader) * config.num_epochs),
... )
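
The shape of this schedule can be sketched in a few lines of plain Python: a linear ramp from 0 to the base learning rate during warmup, followed by half a cosine period that decays to 0 at the final step. This is an illustration of the idea under the default half-cycle setting, not the library code, and the total step count below is an arbitrary example value.

```python
import math


def cosine_with_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 up to the base learning rate.
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half a cosine period: 1.0 right after warmup, 0.0 at the final step.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 1e-4
warmup, total = 500, 3150  # `total` is an illustrative value, not taken from the tutorial
for step in (0, 250, 500, 1825, 3150):
    print(step, base_lr * cosine_with_warmup_multiplier(step, warmup, total))
```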

Then, you need a way to evaluate the model. For evaluation, you can use the DDPMPipeline to generate a batch of sample images and save them as a grid:

>>> from diffusers import DDPMPipeline
>>> import math
>>> import os


>>> def make_grid(images, rows, cols):
...     w, h = images[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, image in enumerate(images):
...         grid.paste(image, box=(i % cols * w, i // cols * h))
...     return grid


>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the reverse diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`
...     images = pipeline(
...         batch_size=config.eval_batch_size,
...         generator=torch.manual_seed(config.seed),
...     ).images

...     # Make a grid out of the images
...     image_grid = make_grid(images, rows=4, cols=4)

...     # Save the images
...     test_dir = os.path.join(config.output_dir, "samples")
...     os.makedirs(test_dir, exist_ok=True)
...     image_grid.save(f"{test_dir}/{epoch:04d}.png")
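
The placement logic in make_grid relies on integer arithmetic: `i % cols` gives the column and `i // cols` gives the row, so images fill the grid row by row. A small sketch of just that arithmetic, without PIL (the helper name `grid_boxes` is hypothetical):

```python
def grid_boxes(num_images, cols, w, h):
    """Top-left (x, y) paste position of each image, filled row by row."""
    return [(i % cols * w, i // cols * h) for i in range(num_images)]


# 16 images of 128x128 pixels in a 4x4 grid, as in evaluate() above.
boxes = grid_boxes(num_images=16, cols=4, w=128, h=128)
print(boxes[0], boxes[3], boxes[4], boxes[15])
```

Note that evaluate() hard-codes a 4x4 grid, which matches eval_batch_size = 16. If you change the batch size, adjust rows and cols so that rows * cols is at least the number of images; otherwise the extra images would be positioned outside the canvas.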

Now you can wrap all of these components together in a training loop with 🤗 Accelerate, which makes TensorBoard logging, gradient accumulation, and mixed precision training easy. To upload the model to the Hub, the loop also creates the repository and then pushes the output folder to it.

💡 The training loop below may look intimidating and long, but it will be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. 🤗

>>> from accelerate import Accelerator
>>> from huggingface_hub import create_repo, upload_folder
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os


>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
...     # Initialize accelerator and tensorboard logging
...     accelerator = Accelerator(
...         mixed_precision=config.mixed_precision,
...         gradient_accumulation_steps=config.gradient_accumulation_steps,
...         log_with="tensorboard",
...         project_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
...         if config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
...         if config.push_to_hub:
...             repo_id = create_repo(
...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
...             ).repo_id
...         accelerator.init_trackers("train_example")

...     # Prepare everything
...     # There is no specific order to remember; just unpack the objects
...     # in the same order you passed them to the prepare method.
...     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
...         model, optimizer, train_dataloader, lr_scheduler
...     )

...     global_step = 0

...     # Now you train the model
...     for epoch in range(config.num_epochs):
...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
...         progress_bar.set_description(f"Epoch {epoch}")

...         for step, batch in enumerate(train_dataloader):
...             clean_images = batch["images"]
...             # Sample noise to add to the images
...             noise = torch.randn(clean_images.shape, device=clean_images.device)
...             bs = clean_images.shape[0]

...             # Sample a random timestep for each image
...             timesteps = torch.randint(
...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
...                 dtype=torch.int64
...             )

...             # Add noise to the clean images according to the noise magnitude at each timestep
...             # (this is the forward diffusion process)
...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

...             with accelerator.accumulate(model):
...                 # Predict the noise residual
...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
...                 loss = F.mse_loss(noise_pred, noise)
...                 accelerator.backward(loss)

...                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
...                 optimizer.step()
...                 lr_scheduler.step()
...                 optimizer.zero_grad()

...             progress_bar.update(1)
...             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
...             progress_bar.set_postfix(**logs)
...             accelerator.log(logs, step=global_step)
...             global_step += 1

...         # After each epoch, optionally sample some demo images with evaluate() and save the model
...         if accelerator.is_main_process:
...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
...                 evaluate(config, epoch, pipeline)

...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 # Save the pipeline locally first so the upload below has files to push
...                 pipeline.save_pretrained(config.output_dir)
...                 if config.push_to_hub:
...                     upload_folder(
...                         repo_id=repo_id,
...                         folder_path=config.output_dir,
...                         commit_message=f"Epoch {epoch}",
...                         ignore_patterns=["step_*", "epoch_*"],
...                     )
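
The two `(epoch + 1) % ... == 0 or epoch == config.num_epochs - 1` conditions decide when sample grids and model checkpoints are written. As a quick sanity check, the sketch below (the helper name `milestone_epochs` is hypothetical) lists the zero-based epochs that trigger each action with the configuration values used in this tutorial:

```python
def milestone_epochs(num_epochs, every):
    """Zero-based epochs where `(epoch + 1) % every == 0` or the epoch is the last one."""
    return [
        epoch
        for epoch in range(num_epochs)
        if (epoch + 1) % every == 0 or epoch == num_epochs - 1
    ]


image_epochs = milestone_epochs(num_epochs=50, every=10)  # save_image_epochs
model_epochs = milestone_epochs(num_epochs=50, every=30)  # save_model_epochs

print("sample grids:", [f"{epoch:04d}.png" for epoch in image_epochs])
print("model saves:", model_epochs)
```

The `or epoch == num_epochs - 1` clause guarantees a final checkpoint even when the number of epochs is not a multiple of the saving interval.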

Phew, that was quite a bit of code! But you're finally ready to launch the training with 🤗 Accelerate's [~accelerate.notebook_launcher] function. Pass the function the training loop, all the training arguments, and the number of processes to use for training (you can change this value to the number of GPUs available to you):

>>> from accelerate import notebook_launcher

>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

>>> notebook_launcher(train_loop, args, num_processes=1)

Once training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!

>>> import glob

>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
>>> Image.open(sample_images[-1])

Next steps

Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques on the 🧨 Diffusers Training Examples page. Here are some examples of what you can learn:

  • Textual Inversion, an algorithm that teaches a model a specific visual concept and integrates it into the generated image.
  • DreamBooth, a technique for generating personalized images of a subject given several input images of that subject.
  • Guide to fine-tuning a Stable Diffusion model on your own dataset.
  • Guide to using LoRA, a memory-efficient technique for fine-tuning very large models faster.