<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
[[open-in-colab]]

# Train a diffusion model

Unconditional image generation is a popular application of diffusion models that generates images resembling those in the dataset used for training. Typically, the best results are obtained by finetuning a pretrained model on a specific dataset. You can find many such checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!

This tutorial will teach you how to train a [`UNet2DModel`] from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.

<Tip>

💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details about how diffusion models work, check out the notebook!

</Tip>

Before you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate installed to simplify training on GPUs. Also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can alternatively use [Weights & Biases](https://docs.wandb.ai/) to track your training).

```bash
!pip install diffusers[training]
```

We encourage you to share your model with the community. To do that, you'll need to log in to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one!). You can log in from a notebook and enter your token when prompted:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```

Or log in from the terminal:

```bash
huggingface-cli login
```

Since the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:

```bash
!sudo apt -qq install git-lfs
!git config --global credential.helper store
```

## Training configuration

For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):

```py
>>> from dataclasses import dataclass


>>> @dataclass
... class TrainingConfig:
...     image_size = 128  # the generated image resolution
...     train_batch_size = 16
...     eval_batch_size = 16  # how many images to sample during evaluation
...     num_epochs = 50
...     gradient_accumulation_steps = 1
...     learning_rate = 1e-4
...     lr_warmup_steps = 500
...     save_image_epochs = 10
...     save_model_epochs = 30
...     mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
...     output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
...     push_to_hub = True  # whether to upload the saved model to the HF Hub
...     hub_model_id = None  # repository name on the HF Hub (defaults to the output_dir name)
...     hub_private_repo = False
...     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
...     seed = 0


>>> config = TrainingConfig()
```

## Load the dataset

You can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library:

```py
>>> from datasets import load_dataset

>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
>>> dataset = load_dataset(config.dataset_name, split="train")
```

💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan), or use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or to `imagefolder` if you're using your own images.

🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html), which you can visualize:

```py
>>> import matplotlib.pyplot as plt

>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
>>> for i, image in enumerate(dataset[:4]["image"]):
...     axs[i].imshow(image)
...     axs[i].set_axis_off()
>>> fig.show()
```

The images are all different sizes, so you'll need to preprocess them first:

- `Resize` changes the image size to the one defined in `config.image_size`.
- `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
- `Normalize` is important for rescaling the pixel values into the [-1, 1] range the model expects.

```py
>>> from torchvision import transforms

>>> preprocess = transforms.Compose(
...     [
...         transforms.Resize((config.image_size, config.image_size)),
...         transforms.RandomHorizontalFlip(),
...         transforms.ToTensor(),
...         transforms.Normalize([0.5], [0.5]),
...     ]
... )
```
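
As a quick sanity check on the `Normalize([0.5], [0.5])` step: `ToTensor` produces values in [0, 1], and `Normalize` then computes `(x - mean) / std`, which with `mean=0.5, std=0.5` lands in [-1, 1]. A minimal pure-Python sketch of that arithmetic (the `normalize` helper is illustrative, not a torchvision API):

```python
# Reproduce the arithmetic of transforms.Normalize([0.5], [0.5]) on a few
# pixel values. ToTensor scales pixels to [0, 1]; Normalize computes
# (x - mean) / std, so mean=0.5, std=0.5 maps [0, 1] onto [-1, 1].
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

pixels = [0.0, 0.25, 0.5, 1.0]
print([normalize(p) for p in pixels])  # [-1.0, -0.5, 0.0, 1.0]
```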

Use 🤗 Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:

```py
>>> def transform(examples):
...     images = [preprocess(image.convert("RGB")) for image in examples["image"]]
...     return {"images": images}


>>> dataset.set_transform(transform)
```

Feel free to visualize the images again to confirm that they've been resized. Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!

```py
>>> import torch

>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
```

## Create a UNet2DModel

Pretrained models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:

```py
>>> from diffusers import UNet2DModel

>>> model = UNet2DModel(
...     sample_size=config.image_size,  # the target image resolution
...     in_channels=3,  # the number of input channels, 3 for RGB images
...     out_channels=3,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "DownBlock2D",
...         "DownBlock2D",
...         "DownBlock2D",
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "DownBlock2D",
...     ),
...     up_block_types=(
...         "UpBlock2D",  # a regular ResNet upsampling block
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...     ),
... )
```

It's often a good idea to quickly check that the sample image shape matches the model output shape:

```py
>>> sample_image = dataset[0]["images"].unsqueeze(0)
>>> print("Input shape:", sample_image.shape)
Input shape: torch.Size([1, 3, 128, 128])

>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
Output shape: torch.Size([1, 3, 128, 128])
```

Great! Next, you'll need a scheduler to add some noise to the image.

## Create a scheduler

The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from noise. During training, the scheduler takes a model output (a sample) from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.

Let's look at the [`DDPMScheduler`] and use its `add_noise` method to add some random noise to the `sample_image` from before:

```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import DDPMScheduler

>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> noise = torch.randn(sample_image.shape)
>>> timesteps = torch.LongTensor([50])
>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
```
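
Under the hood, `add_noise` applies the closed form of the DDPM forward process, `x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε`, where `ᾱ_t` is the cumulative product of `(1 - β_t)`. A pure-Python sketch of that math, assuming a linear beta schedule from 1e-4 to 0.02 over 1000 steps (check `noise_scheduler.config` for the scheduler's actual defaults):

```python
import math

# Assumed linear beta schedule: beta grows from 1e-4 to 0.02 over 1000 steps.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

# alpha_bar_t: cumulative product of (1 - beta) up to timestep t.
alpha_bar = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bar.append(prod)

def add_noise(x0, noise, t):
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    return math.sqrt(alpha_bar[t]) * x0 + math.sqrt(1.0 - alpha_bar[t]) * noise

# Early timesteps keep most of the image; late ones are almost pure noise.
print(alpha_bar[50], alpha_bar[999])
```

At `t=50` (the timestep used above) most of the original signal survives, which is why the butterfly is still recognizable in the noisy image.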

The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:

```py
>>> import torch.nn.functional as F

>>> noise_pred = model(noisy_image, timesteps).sample
>>> loss = F.mse_loss(noise_pred, noise)
```

## Train the model

By now, you have most of the pieces needed to start training the model; all that's left is putting everything together.

First, you'll need an optimizer and a learning rate scheduler:

```py
>>> from diffusers.optimization import get_cosine_schedule_with_warmup

>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
>>> lr_scheduler = get_cosine_schedule_with_warmup(
...     optimizer=optimizer,
...     num_warmup_steps=config.lr_warmup_steps,
...     num_training_steps=(len(train_dataloader) * config.num_epochs),
... )
```
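
Conceptually, `get_cosine_schedule_with_warmup` scales the base learning rate by a multiplier that ramps up linearly over the warmup steps and then decays along half a cosine wave to zero. A pure-Python sketch of that multiplier (an approximation of the library's behavior; `lr_multiplier` is an illustrative helper, not a diffusers API):

```python
import math

def lr_multiplier(step, num_warmup_steps, num_training_steps):
    # Linear warmup from 0 up to 1 over the first num_warmup_steps steps...
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # ...then cosine decay from 1 down to 0 over the remaining steps.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

total = 1000 * 50  # stand-in for len(train_dataloader) * config.num_epochs
warmup = 500       # config.lr_warmup_steps
print(lr_multiplier(0, warmup, total), lr_multiplier(warmup, warmup, total))  # 0.0 1.0
```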

Then, you'll need a way to evaluate the model. For evaluation, you can use [`DDPMPipeline`] to generate a batch of sample images and save them as a grid:

```py
>>> from diffusers import DDPMPipeline
>>> import math
>>> import os


>>> def make_grid(images, rows, cols):
...     w, h = images[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, image in enumerate(images):
...         grid.paste(image, box=(i % cols * w, i // cols * h))
...     return grid


>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the backward diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`.
...     images = pipeline(
...         batch_size=config.eval_batch_size,
...         generator=torch.manual_seed(config.seed),
...     ).images

...     # Make a grid out of the images.
...     image_grid = make_grid(images, rows=4, cols=4)

...     # Save the images.
...     test_dir = os.path.join(config.output_dir, "samples")
...     os.makedirs(test_dir, exist_ok=True)
...     image_grid.save(f"{test_dir}/{epoch:04d}.png")
```
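
The indexing in `make_grid` places image `i` at column `i % cols` and row `i // cols`, filling the grid left to right, top to bottom. A quick pure-Python check of those paste coordinates for a 4×4 grid of 128×128 tiles:

```python
# Each image i is pasted at pixel box ((i % cols) * w, (i // cols) * h),
# so index 0 is the top-left tile and index 15 the bottom-right one.
rows, cols, w, h = 4, 4, 128, 128
boxes = [((i % cols) * w, (i // cols) * h) for i in range(rows * cols)]
print(boxes[0], boxes[3], boxes[4], boxes[15])  # (0, 0) (384, 0) (0, 128) (384, 384)
```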

Now you can wrap all of these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, the training loop also determines your repository id and pushes the saved model there.
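
Before diving into the full loop: conceptually, the `accelerator.accumulate(model)` context used below lets gradients from several small batches add up before each optimizer step, mimicking a larger effective batch size. A minimal pure-Python sketch of that idea using a toy scalar model (not the Accelerate API):

```python
# Toy gradient accumulation: for loss = 0.5 * (w*x - y)^2, the gradient
# w.r.t. w is (w*x - y) * x. Averaging gradients over accum_steps
# micro-batches before stepping matches one step on the combined batch.
def grad(w, x, y):
    return (w * x - y) * x

w, lr, accum_steps = 0.0, 0.1, 4
micro_batches = [(1.0, 2.0), (2.0, 3.0), (0.5, 1.0), (1.5, 2.5)]

accumulated = 0.0
for i, (x, y) in enumerate(micro_batches):
    accumulated += grad(w, x, y) / accum_steps  # running average of gradients
    if (i + 1) % accum_steps == 0:
        w -= lr * accumulated  # one optimizer step per accum_steps micro-batches
        accumulated = 0.0

print(w)  # 0.30625
```

With `gradient_accumulation_steps = 1` in the config above, every batch triggers an optimizer step; raising it trades steps for a larger effective batch.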

💡 The training loop below may look intimidating and long, but it'll be worth it later when you can launch training in just one line of code! If you can't wait and already want to start generating images, feel free to copy and run the code below. 🤗

```py
>>> from accelerate import Accelerator
>>> from huggingface_hub import create_repo, upload_folder
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os


>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
...     # Initialize accelerator and tensorboard logging.
...     accelerator = Accelerator(
...         mixed_precision=config.mixed_precision,
...         gradient_accumulation_steps=config.gradient_accumulation_steps,
...         log_with="tensorboard",
...         project_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
...         if config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
...         if config.push_to_hub:
...             repo_id = create_repo(
...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
...             ).repo_id
...         accelerator.init_trackers("train_example")

...     # Prepare everything. There is no specific order to remember - you just
...     # need to unpack the objects in the same order you gave them to prepare.
...     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
...         model, optimizer, train_dataloader, lr_scheduler
...     )

...     global_step = 0

...     # Now you train the model.
...     for epoch in range(config.num_epochs):
...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
...         progress_bar.set_description(f"Epoch {epoch}")

...         for step, batch in enumerate(train_dataloader):
...             clean_images = batch["images"]
...             # Sample noise to add to the images.
...             noise = torch.randn(clean_images.shape).to(clean_images.device)
...             bs = clean_images.shape[0]

...             # Sample a random timestep for each image.
...             timesteps = torch.randint(
...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
...             ).long()

...             # Add noise to the clean images according to the noise magnitude
...             # at each timestep (this is the forward diffusion process).
...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

...             with accelerator.accumulate(model):
...                 # Predict the noise residual.
...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
...                 loss = F.mse_loss(noise_pred, noise)
...                 accelerator.backward(loss)

...                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
...                 optimizer.step()
...                 lr_scheduler.step()
...                 optimizer.zero_grad()

...             progress_bar.update(1)
...             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
...             progress_bar.set_postfix(**logs)
...             accelerator.log(logs, step=global_step)
...             global_step += 1

...         # After each epoch, optionally sample some demo images with evaluate() and save the model.
...         if accelerator.is_main_process:
...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
...                 evaluate(config, epoch, pipeline)

...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 if config.push_to_hub:
...                     upload_folder(
...                         repo_id=repo_id,
...                         folder_path=config.output_dir,
...                         commit_message=f"Epoch {epoch}",
...                         ignore_patterns=["step_*", "epoch_*"],
...                     )
...                 else:
...                     pipeline.save_pretrained(config.output_dir)
```

Phew, that was quite a bit of code! But now you're ready to launch the training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass it the training loop, all the training arguments, and the number of processes to use for training (you can change this value to the number of GPUs available to you):

```py
>>> from accelerate import notebook_launcher

>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

>>> notebook_launcher(train_loop, args, num_processes=1)
```

Once training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!

```py
>>> import glob

>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
>>> Image.open(sample_images[-1])
```

## Next steps

Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques on the [🧨 Diffusers training examples](../training/overview) page. Here are some examples of what you can learn:

- [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into generated images.
- [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given several input images of that subject.
- A [guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset.
- A [guide](../training/lora) to LoRA, a memory-efficient technique for finetuning really large models faster.