mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
91 lines
4.3 KiB
Markdown
91 lines
4.3 KiB
Markdown
# ์ฌ๋ฌ GPU๋ฅผ ์ฌ์ฉํ ๋ถ์ฐ ์ถ๋ก
|
|
|
|
๋ถ์ฐ ์ค์ ์์๋ ์ฌ๋ฌ ๊ฐ์ ํ๋กฌํํธ๋ฅผ ๋์์ ์์ฑํ ๋ ์ ์ฉํ ๐ค [Accelerate](https://huggingface.co/docs/accelerate/index) ๋๋ [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ GPU์์ ์ถ๋ก ์ ์คํํ ์ ์์ต๋๋ค.
|
|
|
|
์ด ๊ฐ์ด๋์์๋ ๋ถ์ฐ ์ถ๋ก ์ ์ํด ๐ค Accelerate์ PyTorch Distributed๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ๋๋ฆฝ๋๋ค.
|
|
|
|
## ๐ค Accelerate
|
|
|
|
๐ค [Accelerate](https://huggingface.co/docs/accelerate/index)๋ ๋ถ์ฐ ์ค์ ์์ ์ถ๋ก ์ ์ฝ๊ฒ ํ๋ จํ๊ฑฐ๋ ์คํํ ์ ์๋๋ก ์ค๊ณ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์
๋๋ค. ๋ถ์ฐ ํ๊ฒฝ ์ค์ ํ๋ก์ธ์ค๋ฅผ ๊ฐ์ํํ์ฌ PyTorch ์ฝ๋์ ์ง์คํ ์ ์๋๋ก ํด์ค๋๋ค.
|
|
|
|
์์ํ๋ ค๋ฉด Python ํ์ผ์ ์์ฑํ๊ณ [`accelerate.PartialState`]๋ฅผ ์ด๊ธฐํํ์ฌ ๋ถ์ฐ ํ๊ฒฝ์ ์์ฑํ๋ฉด, ์ค์ ์ด ์๋์ผ๋ก ๊ฐ์ง๋๋ฏ๋ก `rank` ๋๋ `world_size`๋ฅผ ๋ช
์์ ์ผ๋ก ์ ์ํ ํ์๊ฐ ์์ต๋๋ค. ['DiffusionPipeline`]์ `distributed_state.device`๋ก ์ด๋ํ์ฌ ๊ฐ ํ๋ก์ธ์ค์ GPU๋ฅผ ํ ๋นํฉ๋๋ค.
|
|
|
|
์ด์ ์ปจํ
์คํธ ๊ด๋ฆฌ์๋ก [`~accelerate.PartialState.split_between_processes`] ์ ํธ๋ฆฌํฐ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ก์ธ์ค ์์ ๋ฐ๋ผ ํ๋กฌํํธ๋ฅผ ์๋์ผ๋ก ๋ถ๋ฐฐํฉ๋๋ค.
|
|
|
|
|
|
```py
|
|
from accelerate import PartialState
|
|
from diffusers import DiffusionPipeline
|
|
|
|
pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
|
distributed_state = PartialState()
|
|
pipeline.to(distributed_state.device)
|
|
|
|
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
|
|
result = pipeline(prompt).images[0]
|
|
result.save(f"result_{distributed_state.process_index}.png")
|
|
```
|
|
|
|
Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
|
|
|
|
```bash
|
|
accelerate launch run_distributed.py --num_processes=2
|
|
```
|
|
|
|
> [!TIP]
|
|
> ์์ธํ ๋ด์ฉ์ [๐ค Accelerate๋ฅผ ์ฌ์ฉํ ๋ถ์ฐ ์ถ๋ก ](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์.
|
|
|
|
## Pytoerch ๋ถ์ฐ
|
|
|
|
PyTorch๋ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)์ ์ง์ํฉ๋๋ค.
|
|
|
|
์์ํ๋ ค๋ฉด Python ํ์ผ์ ์์ฑํ๊ณ `torch.distributed` ๋ฐ `torch.multiprocessing`์ ์ํฌํธํ์ฌ ๋ถ์ฐ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์ค์ ํ๊ณ ๊ฐ GPU์์ ์ถ๋ก ์ฉ ํ๋ก์ธ์ค๋ฅผ ์์ฑํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ [`DiffusionPipeline`]๋ ์ด๊ธฐํํด์ผ ํฉ๋๋ค:
|
|
|
|
ํ์ฐ ํ์ดํ๋ผ์ธ์ `rank`๋ก ์ด๋ํ๊ณ `get_rank`๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ํ๋ก์ธ์ค์ GPU๋ฅผ ํ ๋นํ๋ฉด ๊ฐ ํ๋ก์ธ์ค๊ฐ ๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค:
|
|
|
|
```py
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
from diffusers import DiffusionPipeline
|
|
|
|
sd = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
|
```
|
|
|
|
์ฌ์ฉํ ๋ฐฑ์๋ ์ ํ, ํ์ฌ ํ๋ก์ธ์ค์ `rank`, `world_size` ๋๋ ์ฐธ์ฌํ๋ ํ๋ก์ธ์ค ์๋ก ๋ถ์ฐ ํ๊ฒฝ ์์ฑ์ ์ฒ๋ฆฌํ๋ ํจ์[`init_process_group`]๋ฅผ ๋ง๋ค์ด ์ถ๋ก ์ ์คํํด์ผ ํฉ๋๋ค.
|
|
|
|
2๊ฐ์ GPU์์ ์ถ๋ก ์ ๋ณ๋ ฌ๋ก ์คํํ๋ ๊ฒฝ์ฐ `world_size`๋ 2์
๋๋ค.
|
|
|
|
```py
|
|
def run_inference(rank, world_size):
|
|
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
|
|
|
sd.to(rank)
|
|
|
|
if torch.distributed.get_rank() == 0:
|
|
prompt = "a dog"
|
|
elif torch.distributed.get_rank() == 1:
|
|
prompt = "a cat"
|
|
|
|
image = sd(prompt).images[0]
|
|
image.save(f"./{'_'.join(prompt)}.png")
|
|
```
|
|
|
|
๋ถ์ฐ ์ถ๋ก ์ ์คํํ๋ ค๋ฉด [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)์ ํธ์ถํ์ฌ `world_size`์ ์ ์๋ GPU ์์ ๋ํด `run_inference` ํจ์๋ฅผ ์คํํฉ๋๋ค:
|
|
|
|
```py
|
|
def main():
|
|
world_size = 2
|
|
mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
```
|
|
|
|
์ถ๋ก ์คํฌ๋ฆฝํธ๋ฅผ ์๋ฃํ์ผ๋ฉด `--nproc_per_node` ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉํ GPU ์๋ฅผ ์ง์ ํ๊ณ `torchrun`์ ํธ์ถํ์ฌ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํฉ๋๋ค:
|
|
|
|
```bash
|
|
torchrun run_distributed.py --nproc_per_node=2
|
|
``` |