<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-image

> [!WARNING]
> The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend exploring different hyperparameters to get the best results on your own dataset.

Text-to-image models like Stable Diffusion generate an image from a text prompt. This guide shows how to fine-tune the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model on your own dataset with PyTorch and Flax. All of the training scripts for text-to-image fine-tuning used in this guide can be found in this [repository](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) if you're interested in taking a closer look.

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```

Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment:

```bash
accelerate config
```

If you have already cloned the repository, you don't need to go through these steps. Instead, you can pass the path to your local checkout to the training script, and it will be loaded from there.

## Hardware requirements

With `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a larger `batch_size` and faster training, it's better to use a GPU with more than 30GB of GPU memory. You can also use JAX/Flax for fine-tuning on a TPU or GPU, which is covered [below](#flax-jax-finetuning).

You can reduce memory usage even further by enabling memory-efficient attention with xFormers. Make sure [xFormers is installed](./optimization/xformers) and pass the `--enable_xformers_memory_efficient_attention` flag to the training script.
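
The same memory-efficient attention can also be enabled on a pipeline loaded in Python, for example when generating validation images from a fine-tuned model. This is a minimal sketch, assuming xFormers is installed and a CUDA GPU is available:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load any Stable Diffusion pipeline (a fine-tuned one or the base model)
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe.to("cuda")

# Enable xFormers memory-efficient attention for lower memory usage during inference
pipe.enable_xformers_memory_efficient_attention()
```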

xFormers is not available for Flax.

## Upload the model to the Hub

Store the model on the Hub by adding the following argument to the training script:

```bash
--push_to_hub
```
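
If you trained without `--push_to_hub`, you can still upload the saved pipeline afterwards. This is a minimal sketch, assuming you're logged in to the Hub and using a hypothetical repository name:

```python
from diffusers import StableDiffusionPipeline

# Load the pipeline saved by the training script and push it to a hypothetical Hub repository
pipe = StableDiffusionPipeline.from_pretrained("sd-naruto-model")
pipe.push_to_hub("your-username/sd-naruto-model")
```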

## Save and load checkpoints

It's a good idea to save checkpoints regularly in case anything happens during training. To save a checkpoint, pass the following argument to the training script:

```bash
--checkpointing_steps=500
```

Every 500 steps, the full training state is saved in a subfolder of `output_dir`. The checkpoint folders are named `checkpoint-` followed by the number of steps trained so far. For example, `checkpoint-1500` is a checkpoint saved after 1500 training steps.

To load a checkpoint and resume training, pass the `--resume_from_checkpoint` argument to the training script and specify the checkpoint to resume from. For example, the following argument resumes training from the checkpoint saved after 1500 training steps:

```bash
--resume_from_checkpoint="checkpoint-1500"
```
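
You can also run inference from an intermediate checkpoint before training finishes. This is a minimal sketch, assuming the checkpoint folder contains a `unet` subfolder saved by the training script (the exact layout depends on the script version) and that the paths are illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# Hypothetical checkpoint produced with --checkpointing_steps
unet = UNet2DConditionModel.from_pretrained(
    "sd-naruto-model/checkpoint-1500/unet", torch_dtype=torch.float16
)

# Rebuild a pipeline around the fine-tuned UNet, taking the other components from the base model
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", unet=unet, torch_dtype=torch.float16
)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-checkpoint.png")
```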

## Fine-tuning

<frameworkcontent>
<pt>
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset like this:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-naruto-model"
```

To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
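
For example, a local image folder with a `metadata.jsonl` file holding the captions can be loaded directly with 🤗 Datasets. This is a minimal sketch; the folder layout and the `text` caption column are illustrative:

```python
from datasets import load_dataset

# Hypothetical local layout:
#   my_dataset/train/0001.png
#   my_dataset/train/metadata.jsonl  -> {"file_name": "0001.png", "text": "a caption"}
dataset = load_dataset("imagefolder", data_dir="my_dataset")

# Each example contains an "image" column plus the metadata columns (here, "text")
print(dataset["train"][0])
```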

Modify the script if you want to use custom loading logic; we left pointers in the appropriate places in the code to help you. The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR` and where to save the model with `OUTPUT_DIR`:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR}
```

</pt>
<jax>
With Flax, it's possible to train a Stable Diffusion model faster on TPUs and GPUs thanks to [@duongna211](https://github.com/duongna21). It is very efficient on TPU hardware but works great on GPUs too. The Flax training script doesn't yet support features like gradient checkpointing or gradient accumulation, so you'll need a GPU with at least 30GB of memory or a TPU v3.

Before running the script, make sure the requirements are installed:

```bash
pip install -U -r requirements_flax.txt
```

Then you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:

```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"
```

To fine-tune on your own dataset, prepare it in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).

Modify the script if you want to use custom loading logic; we left pointers in the appropriate places in the code to help you. The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR`:

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"
```

</jax>
</frameworkcontent>

## LoRA

For text-to-image model fine-tuning, you can also use LoRA (Low-Rank Adaptation of Large Language Models), a fine-tuning technique for accelerating the training of large models. See the [LoRA training](lora#text-to-image) guide for more details.
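
If you train with the LoRA script, inference typically loads the base model and then the LoRA weights on top of it. This is a minimal sketch, assuming a hypothetical path for the saved LoRA weights:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
# Hypothetical directory where the LoRA training script saved its weights
pipe.load_lora_weights("path_to_saved_lora")
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-naruto-lora.png")
```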

## Inference

Load the fine-tuned model for inference by passing the model path or the model name on the Hub to [`StableDiffusionPipeline`]:

<frameworkcontent>
<pt>
```python
import torch
from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-naruto.png")
```
</pt>
<jax>
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path_to_saved_model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng across devices
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
images[0].save("yoda-naruto.png")
```
</jax>
</frameworkcontent>