Adding VQGAN Training script (#5483)
* Init commit
* Removed einops
* Added default movq config for training
* Update explanation of prompts
* Fixed inheritance of discriminator and init_tracker
* Fixed incompatible api between muse and here
* Fixed output
* Setup init training
* Basic structure done
* Removed attention for quick tests
* Style fixes
* Fixed vae/vqgan styles
* Removed redefinition of wandb
* Fixed log_validation and tqdm
* Nothing commit
* Added commit loss to lookup_from_codebook
* Update src/diffusers/models/vq_model.py

  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Adding perliminary README
* Fixed one typo
* Local changes
* Fixed main issues
* Merging
* Update src/diffusers/models/vq_model.py

  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Testing+Fixed bugs in training script
* Some style fixes
* Added wandb to docs
* Fixed timm test
* get testing suite ready.
* remove return loss
* remove return_loss
* Remove diffs
* Remove diffs
* fix ruff format

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
.github/workflows/pr_tests.yml (vendored, 2 changes)
```diff
@@ -156,7 +156,7 @@ jobs:
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
           python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install peft
+          python -m uv pip install peft timm
           python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
             examples
```
.github/workflows/push_tests.yml (vendored, 1 change)
```diff
@@ -426,6 +426,7 @@ jobs:
           HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
         run: |
           python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m uv pip install timm
           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/

       - name: Failure short reports
```
.github/workflows/push_tests_fast.yml (vendored, 2 changes)
```diff
@@ -107,7 +107,7 @@ jobs:
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
           python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-          python -m uv pip install peft
+          python -m uv pip install peft timm
           python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
             examples
```
examples/vqgan/README.md (new file, 127 lines)
## Training a VQGAN VAE

VQ-VAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and were later combined with a GAN in [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQ-VAE is that it is a variational autoencoder whose latent space consists of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a [PR to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52), with the general code following [lucidrains' implementation of the VQGAN training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py), but both of these implementations follow from the [taming-transformers repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
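For intuition, here is a minimal sketch of the vector-quantization step that turns continuous encoder outputs into discrete tokens. It is illustrative only and not part of the training script; the function and variable names are assumptions:

```python
import torch


def quantize(latents: torch.Tensor, codebook: torch.Tensor):
    """Map each continuous latent vector to its nearest codebook entry.

    latents:  (batch, num_latents, embed_dim) encoder outputs
    codebook: (vocab_size, embed_dim) learned embedding table
    """
    # Distance from every latent vector to every codebook entry.
    distances = torch.cdist(latents, codebook.unsqueeze(0).expand(latents.size(0), -1, -1))
    tokens = distances.argmin(dim=-1)  # discrete token ids, analogous to LLM tokens
    quantized = codebook[tokens]  # look the quantized vectors back up
    # Straight-through estimator: gradients flow to the encoder as if
    # quantization were the identity function.
    quantized = latents + (quantized - latents).detach()
    return quantized, tokens
```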
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
### Training on CIFAR10

The command to train a VQGAN model on the cifar10 dataset:

```bash
accelerate launch train_vqgan.py \
  --dataset_name=cifar10 \
  --image_column=img \
  --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
  --resolution=128 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=8 \
  --report_to=wandb
```

An example training run by @sayakpaul is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp), and a smaller-scale one is [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
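For example, they can be fetched with `huggingface_hub` (a sketch; note that the files land under `images/vqgan_validation_images/`, while the command above expects them directly under `images/`, so adjust the paths accordingly):

```python
from huggingface_hub import snapshot_download

# Download only the validation images folder from the docs-images dataset repo.
snapshot_download(
    repo_id="diffusers/docs-images",
    repo_type="dataset",
    allow_patterns="vqgan_validation_images/*",
    local_dir="images",
)
```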
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. However, other ways include, but are not limited to, lowering compression by downsampling fewer times, or increasing the vocabulary size, which at most can be around 16384. How to do this is shown below.

## Modifying the architecture

To modify the architecture of the vqgan model, you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide it to the script with the option `--model_config_name_or_path`. This config is below:
```
{
  "_class_name": "VQModel",
  "_diffusers_version": "0.17.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    256,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "AttnDownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "norm_type": "spatial",
  "num_vq_embeddings": 16384,
  "out_channels": 3,
  "sample_size": 32,
  "scaling_factor": 0.18215,
  "up_block_types": [
    "AttnUpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "vq_embed_dim": 4
}
```
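For instance, if the config above is saved as `movq_config.json` (the filename here is just an example), the training command becomes:

```bash
accelerate launch train_vqgan.py \
  --dataset_name=cifar10 \
  --image_column=img \
  --resolution=128 \
  --model_config_name_or_path=movq_config.json \
  --report_to=wandb
```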
To reduce the number of layers in a VQGAN, you can remove layers by modifying `block_out_channels`, `down_block_types`, and `up_block_types` like below:
```
{
  "_class_name": "VQModel",
  "_diffusers_version": "0.17.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    256
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "norm_type": "spatial",
  "num_vq_embeddings": 16384,
  "out_channels": 3,
  "sample_size": 32,
  "scaling_factor": 0.18215,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "vq_embed_dim": 4
}
```
To increase the vocabulary size, you can increase `num_vq_embeddings`. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representations of VQGANs start degrading after 2^14 = 16384 vq embeddings, so it's not recommended to go beyond that.
## Extra training tips/ideas

During logging, take care to make sure `data_time` stays low: it is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower it is to increase `--dataloader_num_workers` to a higher number like 4, as shown below. Due to a bug in PyTorch, this only works on Linux-based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646).
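For instance, extending the CIFAR10 command from above:

```bash
accelerate launch train_vqgan.py \
  --dataset_name=cifar10 \
  --image_column=img \
  --resolution=128 \
  --dataloader_num_workers=4 \
  --report_to=wandb
```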
Secondly, training can be considered done when both the discriminator and the generator losses converge.

Thirdly, another low-hanging fruit is enabling EMA with the `--use_ema` parameter. This tends to make the output images smoother. The downside is that you may have to lower your batch size by 1, but it can be worth it.
Another, more experimental, low-hanging fruit is changing the LPIPS backbone from vgg19 to a different model using `--timm_model_backend`. If you do this, I recommend also changing the `--timm_model_layers` parameter to the layer in your model which you think is best for representation. However, be careful with the feature map norms, since they can easily come to dominate the loss; see the sketch below.
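As a sketch, swapping in a ConvNeXt backbone from timm might look like this (the model and layer names are illustrative assumptions, not tested recommendations):

```bash
accelerate launch train_vqgan.py \
  --dataset_name=cifar10 \
  --image_column=img \
  --resolution=128 \
  --timm_model_backend=convnext_small \
  --timm_model_layers=stages.2 \
  --report_to=wandb
```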
examples/vqgan/discriminator.py (new file, 48 lines)
```python
"""
Ported from Paella
"""

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        # Stack of spectral-normalized strided convolutions: each stage halves the
        # spatial resolution while growing the channel count toward `hidden_channels`.
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = hidden_channels // (2 ** max((d - i), 0))
            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        # 1x1 convolution mapping the (optionally condition-augmented) features to one logit per patch.
        self.shuffle = nn.Conv2d(
            (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
        )
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector across the spatial dimensions and concatenate.
            cond = cond.view(
                cond.size(0),
                cond.size(1),
                1,
                1,
            ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
```
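A quick smoke test of the discriminator's output shape (a sketch, assuming the file above is importable as `discriminator`):

```python
import torch

from discriminator import Discriminator

# Small discriminator; with depth=4 the input is downsampled by 2^4 = 16x.
disc = Discriminator(in_channels=3, cond_channels=0, hidden_channels=64, depth=4)

images = torch.randn(2, 3, 128, 128)  # a fake batch of RGB images
patch_probs = disc(images)  # per-patch "real" probabilities in [0, 1]
print(patch_probs.shape)  # torch.Size([2, 1, 8, 8])
```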
examples/vqgan/requirements.txt (new file, 8 lines)
```
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets
timm
numpy
tqdm
tensorboard
```
examples/vqgan/test_vqgan.py (new file, 395 lines)
```python
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import shutil
import sys
import tempfile

import torch

from diffusers import VQModel
from diffusers.utils.testing_utils import require_timm


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


@require_timm
class TextToImage(ExamplesTestsAccelerate):
    @property
    def test_vqmodel_config(self):
        return {
            "_class_name": "VQModel",
            "_diffusers_version": "0.17.0.dev0",
            "act_fn": "silu",
            "block_out_channels": [
                32,
            ],
            "down_block_types": [
                "DownEncoderBlock2D",
            ],
            "in_channels": 3,
            "latent_channels": 4,
            "layers_per_block": 2,
            "norm_num_groups": 32,
            "norm_type": "spatial",
            "num_vq_embeddings": 32,
            "out_channels": 3,
            "sample_size": 32,
            "scaling_factor": 0.18215,
            "up_block_types": [
                "UpDecoderBlock2D",
            ],
            "vq_embed_dim": 4,
        }

    @property
    def test_discriminator_config(self):
        return {
            "_class_name": "Discriminator",
            "_diffusers_version": "0.27.0.dev0",
            "in_channels": 3,
            "cond_channels": 0,
            "hidden_channels": 8,
            "depth": 4,
        }

    def get_vq_and_discriminator_configs(self, tmpdir):
        vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json")
        discriminator_config_path = os.path.join(tmpdir, "discriminator.json")
        with open(vqmodel_config_path, "w") as fp:
            json.dump(self.test_vqmodel_config, fp)
        with open(discriminator_config_path, "w") as fp:
            json.dump(self.test_discriminator_config, fp)
        return vqmodel_config_path, discriminator_config_path

    def test_vqmodel(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            test_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(
                os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors")))

    def test_vqmodel_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            # Run training script with checkpointing
            # max_train_steps == 4, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 4
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --checkpointing_steps=2
                --output_dir {tmpdir}
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            # check checkpoint directories exist
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4"},
            )

            # check can run an intermediate checkpoint
            model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4"},
            )

            # Run training script for 2 total steps resuming from checkpoint 4

            resume_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 6
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --checkpointing_steps=1
                --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
                --output_dir {tmpdir}
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # no checkpoint-2 -> check old checkpoints do not exist
            # check new checkpoints exist
            # In the current script, checkpointing_steps 1 is equivalent to checkpointing_steps 2 as after the generator gets trained for one step,
            # the discriminator gets trained and loss and saving happens after that. Thus we do not expect to get a checkpoint-5
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_vqmodel_checkpointing_use_ema(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            # Run training script with checkpointing
            # max_train_steps == 4, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 4
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --checkpointing_steps=2
                --output_dir {tmpdir}
                --use_ema
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # check checkpoint directories exist
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4"},
            )

            # check can run an intermediate checkpoint
            model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Run training script for 2 total steps resuming from checkpoint 4

            resume_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 6
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --checkpointing_steps=1
                --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
                --output_dir {tmpdir}
                --use_ema
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # no checkpoint-2 -> check old checkpoints do not exist
            # check new checkpoints exist
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_vqmodel_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            # Run training script with checkpointing
            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
            # Should create checkpoints at steps 2, 4, 6
            # with checkpoint at step 2 deleted

            initial_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 6
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --checkpoints_total_limit=2
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # check checkpoint directories exist
            # checkpoint-2 should have been deleted
            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})

    def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
            # Run training script with checkpointing
            # max_train_steps == 4, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 4
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --checkpointing_steps=2
                --output_dir {tmpdir}
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # check checkpoint directories exist
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4"},
            )

            # resume and we should try to checkpoint at 6, where we'll have to remove
            # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint

            resume_run_args = f"""
                examples/vqgan/train_vqgan.py
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 32
                --image_column image
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 8
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --model_config_name_or_path {vqmodel_config_path}
                --discriminator_config_name_or_path {discriminator_config_path}
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
                --checkpoints_total_limit=2
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
            _ = model(image)

            # check checkpoint directories exist
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-6", "checkpoint-8"},
            )
```
examples/vqgan/train_vqgan.py (new file, 1067 lines; diff suppressed because it is too large)
src/diffusers/models/vq_model.py

```diff
@@ -41,6 +41,7 @@ class DecoderOutput(BaseOutput):
     """

     sample: torch.Tensor
+    commit_loss: Optional[torch.FloatTensor] = None


 class Encoder(nn.Module):
```
```diff
@@ -142,18 +142,20 @@ class VQModel(ModelMixin, ConfigMixin):
     ) -> Union[DecoderOutput, torch.Tensor]:
         # also go through quantization layer
         if not force_not_quantize:
-            quant, _, _ = self.quantize(h)
+            quant, commit_loss, _ = self.quantize(h)
         elif self.config.lookup_from_codebook:
             quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         else:
             quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
         quant2 = self.post_quant_conv(quant)
         dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

         if not return_dict:
-            return (dec,)
+            return dec, commit_loss

-        return DecoderOutput(sample=dec)
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)

     def forward(
         self, sample: torch.Tensor, return_dict: bool = True
@@ -173,9 +175,8 @@ class VQModel(ModelMixin, ConfigMixin):
         """

         h = self.encode(sample).latents
-        dec = self.decode(h).sample
+        dec = self.decode(h)

         if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
+            return dec.sample, dec.commit_loss
+        return dec
```
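With this change, training code can read the commitment loss straight off the model output. A minimal sketch of folding it into a reconstruction objective, as the training script's generator loss does conceptually (the variable names and the 0.25 weight are illustrative assumptions):

```python
import torch.nn.functional as F

# `model` is a diffusers VQModel; `pixel_values` is a batch of images in [-1, 1].
output = model(pixel_values)
recon_loss = F.mse_loss(output.sample, pixel_values)
# commit_loss encourages encoder outputs to stay close to their codebook entries.
loss = recon_loss + 0.25 * output.commit_loss.mean()
```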
src/diffusers/utils/__init__.py

```diff
@@ -76,6 +76,7 @@ from .import_utils import (
     is_safetensors_available,
     is_scipy_available,
     is_tensorboard_available,
+    is_timm_available,
     is_torch_available,
     is_torch_npu_available,
     is_torch_version,
```
src/diffusers/utils/import_utils.py

```diff
@@ -295,6 +295,19 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _torchvision_available = False

+_timm_available = importlib.util.find_spec("timm") is not None
+if _timm_available:
+    try:
+        _timm_version = importlib_metadata.version("timm")
+        logger.info(f"Timm version {_timm_version} available.")
+    except importlib_metadata.PackageNotFoundError:
+        _timm_available = False
+
+
+def is_timm_available():
+    return _timm_available
+
+
 _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
 try:
     _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
```
src/diffusers/utils/testing_utils.py

```diff
@@ -33,6 +33,7 @@ from .import_utils import (
     is_onnx_available,
     is_opencv_available,
     is_peft_available,
+    is_timm_available,
     is_torch_available,
     is_torch_version,
     is_torchsde_available,
@@ -340,6 +341,13 @@ def require_peft_backend(test_case):
     return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)


+def require_timm(test_case):
+    """
+    Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
+    """
+    return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+
+
 def require_peft_version_greater(peft_version):
     """
     Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
```
```diff
@@ -98,3 +98,19 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+    def test_loss_pretrained(self):
+        model = VQModel.from_pretrained("fusing/vqgan-dummy")
+        model.to(torch_device).eval()
+
+        torch.manual_seed(0)
+        backend_manual_seed(torch_device, 0)
+
+        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+        image = image.to(torch_device)
+        with torch.no_grad():
+            output = model(image).commit_loss.cpu()
+        # fmt: off
+        expected_output = torch.tensor([0.1936])
+        # fmt: on
+        self.assertTrue(torch.allclose(output, expected_output, atol=1e-3))
```