1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/docs/source/ko/training/text2image.md
Sayak Paul 30e5e81d58 change to 2024 in the license (#6902)
change to 2024
2024-02-08 08:19:31 -10:00

224 lines
10 KiB
Markdown

<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Text-to-image
<Tip warning={true}>
text-to-image ํŒŒ์ธํŠœ๋‹ ์Šคํฌ๋ฆฝํŠธ๋Š” experimental ์ƒํƒœ์ž…๋‹ˆ๋‹ค. ๊ณผ์ ํ•ฉํ•˜๊ธฐ ์‰ฝ๊ณ  ์น˜๋ช…์ ์ธ ๋ง๊ฐ๊ณผ ๊ฐ™์€ ๋ฌธ์ œ์— ๋ถ€๋”ชํžˆ๊ธฐ ์‰ฝ์Šต๋‹ˆ๋‹ค. ์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์—์„œ ์ตœ์ƒ์˜ ๊ฒฐ๊ณผ๋ฅผ ์–ป์œผ๋ ค๋ฉด ๋‹ค์–‘ํ•œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ํƒ์ƒ‰ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
</Tip>
Stable Diffusion๊ณผ ๊ฐ™์€ text-to-image ๋ชจ๋ธ์€ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ฐ€์ด๋“œ๋Š” PyTorch ๋ฐ Flax๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์—์„œ [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) ๋ชจ๋ธ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด ๊ฐ€์ด๋“œ์— ์‚ฌ์šฉ๋œ text-to-image ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•œ ๋ชจ๋“  ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๊ด€์‹ฌ์ด ์žˆ๋Š” ๊ฒฝ์šฐ ์ด [๋ฆฌํฌ์ง€ํ† ๋ฆฌ](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)์—์„œ ์ž์„ธํžˆ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „์—, ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ํ•™์Šต dependency๋“ค์„ ์„ค์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:
```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```
๊ทธ๋ฆฌ๊ณ  [๐Ÿค—Accelerate](https://github.com/huggingface/accelerate/) ํ™˜๊ฒฝ์„ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค:
```bash
accelerate config
```
๋ฆฌํฌ์ง€ํ† ๋ฆฌ๋ฅผ ์ด๋ฏธ ๋ณต์ œํ•œ ๊ฒฝ์šฐ, ์ด ๋‹จ๊ณ„๋ฅผ ์ˆ˜ํ–‰ํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๋Œ€์‹ , ๋กœ์ปฌ ์ฒดํฌ์•„์›ƒ ๊ฒฝ๋กœ๋ฅผ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋ช…์‹œํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ ๊ฑฐ๊ธฐ์—์„œ ๋กœ๋“œ๋ฉ๋‹ˆ๋‹ค.
### ํ•˜๋“œ์›จ์–ด ์š”๊ตฌ ์‚ฌํ•ญ
`gradient_checkpointing` ๋ฐ `mixed_precision`์„ ์‚ฌ์šฉํ•˜๋ฉด ๋‹จ์ผ 24GB GPU์—์„œ ๋ชจ๋ธ์„ ํŒŒ์ธํŠœ๋‹ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋” ๋†’์€ `batch_size`์™€ ๋” ๋น ๋ฅธ ํ›ˆ๋ จ์„ ์œ„ํ•ด์„œ๋Š” GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ 30GB ์ด์ƒ์ธ GPU๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. TPU ๋˜๋Š” GPU์—์„œ ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•ด JAX๋‚˜ Flax๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ [์•„๋ž˜](#flax-jax-finetuning)๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
xFormers๋กœ memory efficient attention์„ ํ™œ์„ฑํ™”ํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ํ›จ์”ฌ ๋” ์ค„์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. [xFormers๊ฐ€ ์„ค์น˜](./optimization/xformers)๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜๊ณ  `--enable_xformers_memory_efficient_attention`๋ฅผ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋ช…์‹œํ•ฉ๋‹ˆ๋‹ค.
xFormers๋Š” Flax์— ์‚ฌ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.
## Hub์— ๋ชจ๋ธ ์—…๋กœ๋“œํ•˜๊ธฐ
ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ๋ชจ๋ธ์„ ํ—ˆ๋ธŒ์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค:
```bash
--push_to_hub
```
## ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
ํ•™์Šต ์ค‘ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ์ผ์— ๋Œ€๋น„ํ•˜์—ฌ ์ •๊ธฐ์ ์œผ๋กœ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•ด ๋‘๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ˆ˜๋ฅผ ๋ช…์‹œํ•ฉ๋‹ˆ๋‹ค.
```bash
--checkpointing_steps=500
```
500์Šคํ…๋งˆ๋‹ค ์ „์ฒด ํ•™์Šต state๊ฐ€ 'output_dir'์˜ ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋Š” 'checkpoint-'์— ์ง€๊ธˆ๊นŒ์ง€ ํ•™์Šต๋œ step ์ˆ˜์ž…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด 'checkpoint-1500'์€ 1500 ํ•™์Šต step ํ›„์— ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์ž…๋‹ˆ๋‹ค.
ํ•™์Šต์„ ์žฌ๊ฐœํ•˜๊ธฐ ์œ„ํ•ด ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋ ค๋ฉด '--resume_from_checkpoint' ์ธ์ˆ˜๋ฅผ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋ช…์‹œํ•˜๊ณ  ์žฌ๊ฐœํ•  ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ง€์ •ํ•˜์‹ญ์‹œ์˜ค. ์˜ˆ๋ฅผ ๋“ค์–ด ๋‹ค์Œ ์ธ์ˆ˜๋Š” 1500๊ฐœ์˜ ํ•™์Šต step ํ›„์— ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์—์„œ๋ถ€ํ„ฐ ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•ฉ๋‹ˆ๋‹ค.
```bash
--resume_from_checkpoint="checkpoint-1500"
```
## ํŒŒ์ธํŠœ๋‹
<frameworkcontent>
<pt>
๋‹ค์Œ๊ณผ ๊ฐ™์ด [Pokรฉmon BLIP ์บก์…˜](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) ๋ฐ์ดํ„ฐ์…‹์—์„œ ํŒŒ์ธํŠœ๋‹ ์‹คํ–‰์„ ์œ„ํ•ด [PyTorch ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```
์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋ ค๋ฉด ๐Ÿค— [Datasets](https://huggingface.co/docs/datasets/index)์—์„œ ์š”๊ตฌํ•˜๋Š” ํ˜•์‹์— ๋”ฐ๋ผ ๋ฐ์ดํ„ฐ์…‹์„ ์ค€๋น„ํ•˜์„ธ์š”. [๋ฐ์ดํ„ฐ์…‹์„ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œ](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)ํ•˜๊ฑฐ๋‚˜ [ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋กœ์ปฌ ํด๋”๋ฅผ ์ค€๋น„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž ์ปค์Šคํ…€ loading logic์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ˆ˜์ •ํ•˜์‹ญ์‹œ์˜ค. ๋„์›€์ด ๋˜๋„๋ก ์ฝ”๋“œ์˜ ์ ์ ˆํ•œ ์œ„์น˜์— ํฌ์ธํ„ฐ๋ฅผ ๋‚จ๊ฒผ์Šต๋‹ˆ๋‹ค. ๐Ÿค— ์•„๋ž˜ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋Š” `TRAIN_DIR`์˜ ๋กœ์ปฌ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ๋ฅผ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ `OUTPUT_DIR`์—์„œ ๋ชจ๋ธ์„ ์ €์žฅํ•  ์œ„์น˜๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR}
```
</pt>
<jax>
[@duongna211](https://github.com/duongna21)์˜ ๊ธฐ์—ฌ๋กœ, Flax๋ฅผ ์‚ฌ์šฉํ•ด TPU ๋ฐ GPU์—์„œ Stable Diffusion ๋ชจ๋ธ์„ ๋” ๋น ๋ฅด๊ฒŒ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” TPU ํ•˜๋“œ์›จ์–ด์—์„œ ๋งค์šฐ ํšจ์œจ์ ์ด์ง€๋งŒ GPU์—์„œ๋„ ํ›Œ๋ฅญํ•˜๊ฒŒ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋Š” gradient checkpointing๋‚˜ gradient accumulation๊ณผ ๊ฐ™์€ ๊ธฐ๋Šฅ์„ ์•„์ง ์ง€์›ํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ 30GB ์ด์ƒ์ธ GPU ๋˜๋Š” TPU v3๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „์— ์š”๊ตฌ ์‚ฌํ•ญ์ด ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์‹ญ์‹œ์˜ค:
```bash
pip install -U -r requirements_flax.txt
```
๊ทธ๋Ÿฌ๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด [Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py)๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋ ค๋ฉด ๐Ÿค— [Datasets](https://huggingface.co/docs/datasets/index)์—์„œ ์š”๊ตฌํ•˜๋Š” ํ˜•์‹์— ๋”ฐ๋ผ ๋ฐ์ดํ„ฐ์…‹์„ ์ค€๋น„ํ•˜์„ธ์š”. [๋ฐ์ดํ„ฐ์…‹์„ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œ](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)ํ•˜๊ฑฐ๋‚˜ [ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋กœ์ปฌ ํด๋”๋ฅผ ์ค€๋น„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž ์ปค์Šคํ…€ loading logic์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ˆ˜์ •ํ•˜์‹ญ์‹œ์˜ค. ๋„์›€์ด ๋˜๋„๋ก ์ฝ”๋“œ์˜ ์ ์ ˆํ•œ ์œ„์น˜์— ํฌ์ธํ„ฐ๋ฅผ ๋‚จ๊ฒผ์Šต๋‹ˆ๋‹ค. ๐Ÿค— ์•„๋ž˜ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋Š” `TRAIN_DIR`์˜ ๋กœ์ปฌ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ๋ฅผ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
</jax>
</frameworkcontent>
## LoRA
Text-to-image ๋ชจ๋ธ ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•ด, ๋Œ€๊ทœ๋ชจ ๋ชจ๋ธ ํ•™์Šต์„ ๊ฐ€์†ํ™”ํ•˜๊ธฐ ์œ„ํ•œ ํŒŒ์ธํŠœ๋‹ ๊ธฐ์ˆ ์ธ LoRA(Low-Rank Adaptation of Large Language Models)๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ [LoRA ํ•™์Šต](lora#text-to-image) ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
## ์ถ”๋ก 
ํ—ˆ๋ธŒ์˜ ๋ชจ๋ธ ๊ฒฝ๋กœ ๋˜๋Š” ๋ชจ๋ธ ์ด๋ฆ„์„ [`StableDiffusionPipeline`]์— ์ „๋‹ฌํ•˜์—ฌ ์ถ”๋ก ์„ ์œ„ํ•ด ํŒŒ์ธ ํŠœ๋‹๋œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:
<frameworkcontent>
<pt>
```python
from diffusers import StableDiffusionPipeline
model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
</pt>
<jax>
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
model_path = "path_to_saved_model"
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
prompt = "yoda pokemon"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
image.save("yoda-pokemon.png")
```
</jax>
</frameworkcontent>