mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into training-group-offloading-tests
This commit is contained in:
95
examples/research_projects/sana/README.md
Normal file
95
examples/research_projects/sana/README.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# Training SANA Sprint Diffuser
|
||||
|
||||
This README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://arxiv.org/abs/2503.09641).
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Define the local paths
|
||||
|
||||
Set a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.
|
||||
|
||||
```bash
|
||||
your_local_path='output' # Or any other path you prefer
|
||||
mkdir -p $your_local_path # Create the directory if it doesn't exist
|
||||
```
|
||||
|
||||
### 2. Download the pre-trained model
|
||||
|
||||
Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.
|
||||
|
||||
```bash
|
||||
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
|
||||
```
|
||||
|
||||
*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*
|
||||
|
||||
### 3. Acquire the dataset shards
|
||||
|
||||
The training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.
|
||||
|
||||
The script specifically uses these three files:
|
||||
* `data/train_000.parquet`
|
||||
* `data/train_001.parquet`
|
||||
* `data/train_002.parquet`
|
||||
|
||||
|
||||
|
||||
You can either:
|
||||
|
||||
Let the script download the dataset automatically during first run
|
||||
|
||||
Or download it manually
|
||||
|
||||
**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. This script example explicitly trains *only* on the three specified shards.
|
||||
|
||||
## Usage
|
||||
|
||||
Once the model is downloaded, you can run the training script.
|
||||
|
||||
```bash
|
||||
|
||||
your_local_path='output' # Ensure this variable is set
|
||||
|
||||
python train_sana_sprint_diffusers.py \
|
||||
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
|
||||
--output_dir=$your_local_path \
|
||||
--mixed_precision=bf16 \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=30000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--dataset_name='brivangl/midjourney-v6-llava' \
|
||||
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
|
||||
--checkpointing_steps=500 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--seed=453645634 \
|
||||
--train_largest_timestep \
|
||||
--misaligned_pairs_D \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
```
|
||||
|
||||
### Explanation of parameters
|
||||
|
||||
* `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.
|
||||
* `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.
|
||||
* `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.
|
||||
* `--resolution`: The image resolution used for training (1024x1024).
|
||||
* `--learning_rate`: The learning rate for the optimizer.
|
||||
* `--max_train_steps`: The total number of training steps to perform.
|
||||
* `--dataloader_num_workers`: Number of worker processes for loading data. Increase for faster data loading if your CPU and disk can handle it.
|
||||
* `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).
|
||||
* `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.
|
||||
* `--checkpointing_steps`: Save a training checkpoint every X steps.
|
||||
* `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.
|
||||
* `--train_batch_size`: The batch size per GPU.
|
||||
* `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.
|
||||
* `--seed`: Random seed for reproducibility.
|
||||
* `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.
|
||||
* `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.
|
||||
* `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.
|
||||
* `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.
|
||||
|
||||
|
||||
1781
examples/research_projects/sana/train_sana_sprint_diffusers.py
Normal file
1781
examples/research_projects/sana/train_sana_sprint_diffusers.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,26 @@
|
||||
your_local_path='output'
|
||||
|
||||
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
|
||||
|
||||
# or Sana_Sprint_0.6B_1024px_teacher_diffusers
|
||||
|
||||
python train_sana_sprint_diffusers.py \
|
||||
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
|
||||
--output_dir=$your_local_path \
|
||||
--mixed_precision=bf16 \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=30000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--dataset_name='brivangl/midjourney-v6-llava' \
|
||||
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
|
||||
--checkpointing_steps=500 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--seed=453645634 \
|
||||
--train_largest_timestep \
|
||||
--misaligned_pairs_D \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import Timesteps
|
||||
@@ -29,6 +29,10 @@ from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class CosmosPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
|
||||
|
||||
@@ -40,6 +40,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.import_utils import is_transformers_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
@@ -312,8 +313,19 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.0.dev0"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user