1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Docs] Update and make improvements (#5819)

Update and make improvements
This commit is contained in:
M. Tolga Cangöz
2023-11-17 00:47:25 +03:00
committed by GitHub
parent a042909c83
commit c697f52476
7 changed files with 18 additions and 20 deletions

View File

@@ -47,7 +47,7 @@ limitations under the License.
## Installation
We recommend installing 🤗 Diffusers in a virtual environment from PyPi or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
### PyTorch
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
## Quickstart
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 14000+ checkpoints):
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 15000+ checkpoints):
```python
from diffusers import DiffusionPipeline
@@ -94,14 +94,13 @@ You can also dig into the models and schedulers toolbox to build your own diffus
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import torch
import numpy as np
scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
scheduler.set_timesteps(50)
sample_size = model.config.sample_size
noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
input = noise
for t in scheduler.timesteps:
@@ -136,8 +135,7 @@ You can look out for [issues](https://github.com/huggingface/diffusers/issues) y
- See [New model/pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) to contribute exciting new diffusion models / diffusion pipelines
- See [New scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz98XR"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a>. We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or
just hang out ☕.
Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz98XR"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a>. We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or just hang out ☕.
## Popular Tasks & Pipelines

View File

@@ -194,9 +194,9 @@ unet_runs_per_experiment = 50
# load inputs
def generate_inputs():
sample = torch.randn(2, 4, 64, 64).half().cuda()
timestep = torch.rand(1).half().cuda() * 999
encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states

View File

@@ -321,13 +321,13 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
... for step, batch in enumerate(train_dataloader):
... clean_images = batch["images"]
... # Sample noise to add to the images
... noise = torch.randn(clean_images.shape).to(clean_images.device)
... noise = torch.randn(clean_images.shape, device=clean_images.device)
... bs = clean_images.shape[0]
... # Sample a random timestep for each image
... timesteps = torch.randint(
... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
... ).long()
... )
... # Add noise to the clean images according to the noise magnitude at each timestep
... # (this is the forward diffusion process)

View File

@@ -71,7 +71,7 @@ tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
>>> import torch
>>> sample_size = model.config.sample_size
>>> noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
>>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
```
5. Now write a loop to iterate over the timesteps. At each timestep, the model does a [`UNet2DModel.forward`] pass and returns the noisy residual. The scheduler's [`~DDPMScheduler.step`] method takes the noisy residual, timestep, and input and it predicts the image at the previous timestep. This output becomes the next input to the model in the denoising loop, and it'll repeat until it reaches the end of the `timesteps` array.
@@ -216,8 +216,8 @@ Next, generate some initial random noise as a starting point for the diffusion p
>>> latents = torch.randn(
... (batch_size, unet.config.in_channels, height // 8, width // 8),
... generator=generator,
... device=torch_device,
... )
>>> latents = latents.to(torch_device)
```
### Denoise the image

View File

@@ -273,9 +273,9 @@ unet_runs_per_experiment = 50
# 입력 불러오기
def generate_inputs():
sample = torch.randn(2, 4, 64, 64).half().cuda()
timestep = torch.rand(1).half().cuda() * 999
encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states

View File

@@ -322,13 +322,13 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
... for step, batch in enumerate(train_dataloader):
... clean_images = batch["images"]
... # 이미지에 더할 노이즈를 샘플링합니다.
... noise = torch.randn(clean_images.shape).to(clean_images.device)
... noise = torch.randn(clean_images.shape, device=clean_images.device)
... bs = clean_images.shape[0]
... # 각 이미지를 위한 랜덤한 타임스텝(timestep)을 샘플링합니다.
... timesteps = torch.randint(
... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
... ).long()
... )
... # 각 타임스텝의 노이즈 크기에 따라 깨끗한 이미지에 노이즈를 추가합니다.
... # (이는 foward diffusion 과정입니다.)

View File

@@ -71,7 +71,7 @@ specific language governing permissions and limitations under the License.
>>> import torch
>>> sample_size = model.config.sample_size
>>> noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
>>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
```
5. 이제 timestep을 반복하는 루프를 작성합니다. 각 timestep에서 모델은 [`UNet2DModel.forward`]를 통해 noisy residual을 반환합니다. 스케줄러의 [`~DDPMScheduler.step`] 메서드는 noisy residual, timestep, 그리고 입력을 받아 이전 timestep에서 이미지를 예측합니다. 이 출력은 노이즈 제거 루프의 모델에 대한 다음 입력이 되며, `timesteps` 배열의 끝에 도달할 때까지 반복됩니다.
@@ -212,8 +212,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
>>> latents = torch.randn(
... (batch_size, unet.in_channels, height // 8, width // 8),
... generator=generator,
... device=torch_device,
... )
>>> latents = latents.to(torch_device)
```
### 이미지 노이즈 제거