
[LoRA] add LoRA support to HiDream and fine-tuning script (#11281)

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>

* move prompt embeds, pooled embeds outside

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: hlky <hlky@hlky.ac>

* fix import

* fix import and tokenizer 4, text encoder 4 loading

* te

* prompt embeds

* fix naming

* shapes

* initial commit to add HiDreamImageLoraLoaderMixin

* fix init

* add tests

* loader

* fix model input

* add code example to readme

* fix default max length of text encoders

* prints

* nullify training cond in unpatchify for temp fix to incompatible shaping of transformer output during training

* smol fix

* unpatchify

* unpatchify

* fix validation

* flip pred and loss

* fix shift!!!

* revert unpatchify changes (for now)

* smol fix

* Apply style fixes

* workaround moe training

* workaround moe training

* remove prints

* to reduce some memory, keep vae in `weight_dtype` same as we have for flux (as it's the same vae)
bbd0c161b5/examples/dreambooth/train_dreambooth_lora_flux.py (L1207)

* refactor to align with HiDream refactor

* refactor to align with HiDream refactor

* refactor to align with HiDream refactor

* add support for cpu offloading of text encoders

* Apply style fixes

* adjust lr and rank for train example

* fix copies

* Apply style fixes

* update README

* update README

* update README

* fix license

* keep prompt2,3,4 as None in validation

* remove reverse ode comment

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* vae offload change

* fix text encoder offloading

* Apply style fixes

* cleaner to_kwargs

* fix module name in copied from

* add requirements

* fix offloading

* fix offloading

* fix offloading

* update transformers version in reqs

* try AutoTokenizer

* try AutoTokenizer

* Apply style fixes

* empty commit

* Delete tests/lora/test_lora_layers_hidream.py

* change tokenizer_4 to load with AutoTokenizer as well

* make text_encoder_four and tokenizer_four configurable

* save model card

* save model card

* revert T5

* fix test

* remove non diffusers lumina2 conversion

---------

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Author: Linoy Tsaban
Date: 2025-04-22 11:44:02 +03:00 (committed by GitHub)
Parent: 6ab62c7431
Commit: e30d3bf544
10 changed files with 2437 additions and 7 deletions


@@ -28,6 +28,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.
<Tip>
@@ -91,6 +92,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
## HiDreamImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin


@@ -0,0 +1,133 @@
# DreamBooth training example for HiDream Image
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/).
It will also allow us to push the trained model parameters to the Hugging Face Hub.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the `examples/dreambooth` folder and run:
```bash
pip install -r requirements_hidream.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure `peft>=0.14.0` is installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
Passing `--push_to_hub` will also upload the trained LoRA parameters to the Hugging Face Hub.
Now, we can launch training using:
> [!NOTE]
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
> the 8-bit Adam optimizer, latent caching, offloading, and no validation.
> Additionally, when only `--instance_prompt` is provided and no `--caption_column` (used for per-image custom prompts),
> the text embeddings are pre-computed to save memory.
```bash
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-hidream-lora"
accelerate launch train_dreambooth_lora_hidream.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--rank=16 \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--cache_latents \
--gradient_checkpointing \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To use `--push_to_hub`, make sure you're logged in to your Hugging Face account:
```bash
huggingface-cli login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before (see the short sketch after this list for logging in from Python).
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, which lets us qualitatively check whether training is progressing as expected.
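If you prefer to authenticate with Weights and Biases from Python instead of the CLI, here is a minimal sketch (the API key is a placeholder you need to fill in):
```python
import wandb

# Equivalent to running `wandb login <your_api_key>` in the shell.
wandb.login(key="<your_api_key>")
```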
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training to. Specify the layers as a comma-separated string, e.g. `"to_k,to_q,to_v"` will result in LoRA training of the attention layers only.
* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16.
We provide several options for reducing memory usage:
* `--offload`: When enabled, the text encoder and VAE are offloaded to the CPU when they are not in use.
* `--cache_latents`: When enabled, the latents are pre-computed from the input images with the VAE, and the VAE is removed from memory once done.
* `--use_8bit_adam`: When enabled, the 8-bit version of AdamW provided by the `bitsandbytes` library is used (see the sketch after this list).
* `--instance_prompt` with no `--caption_column`: when only an instance prompt is provided, the text embeddings are pre-computed and the text encoders are removed from memory once done.
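As a rough illustration (not the exact code from the training script), `--use_8bit_adam` amounts to swapping the optimizer for the 8-bit AdamW from `bitsandbytes`:
```python
import torch
import bitsandbytes as bnb

# Stand-in module for the LoRA-wrapped transformer used by the script.
model = torch.nn.Linear(64, 64)
params_to_optimize = [p for p in model.parameters() if p.requires_grad]
# Drop-in replacement for torch.optim.AdamW with 8-bit optimizer states.
optimizer = bnb.optim.AdamW8bit(params_to_optimize, lr=2e-4)
```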
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to learn more about the model.
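Once training is done, a minimal inference sketch for the trained LoRA looks like the following. The LoRA repo id is a hypothetical placeholder, and depending on the checkpoint layout you may need to load the Llama-3.1 tokenizer and text encoder (`tokenizer_4` / `text_encoder_4`) separately and pass them to `from_pretrained`:
```python
import torch
from diffusers import HiDreamImagePipeline

# Depending on the checkpoint layout, tokenizer_4/text_encoder_4 (Llama 3.1)
# may need to be loaded separately and passed here.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# Hypothetical repo id produced by --push_to_hub above.
pipe.load_lora_weights("your-username/trained-hidream-lora")

image = pipe("a photo of sks dog in a bucket").images[0]
image.save("sks_dog.png")
```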


@@ -0,0 +1,8 @@
accelerate>=1.4.0
torchvision
transformers>=4.50.0
ftfy
tensorboard
Jinja2
peft>=0.14.0
sentencepiece


@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"
def test_dreambooth_lora_hidream(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, only params of
# `self.transformer_layer_type` should be in the state dict.
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()
resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

File diff suppressed because it is too large


@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -108,6 +109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,


@@ -5360,6 +5360,325 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`HiDreamImageTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
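As a quick illustration of the new mixin's class-level API, the following hedged sketch fetches a LoRA state dict through `HiDreamImagePipeline` (the repo id is a hypothetical placeholder):
```python
from diffusers import HiDreamImagePipeline

# lora_state_dict is a classmethod, so no pipeline instance is needed.
state_dict = HiDreamImagePipeline.lora_state_dict("your-username/trained-hidream-lora")
# Keys are expected to be prefixed with "transformer." for this mixin.
print(sorted(state_dict)[:5])
```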


@@ -56,6 +56,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
}


@@ -275,7 +275,14 @@ class HiDreamAttnProcessor:
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
def __init__(
self,
embed_dim,
num_routed_experts=4,
num_activated_experts=2,
aux_loss_alpha=0.01,
_force_inference_output=False,
):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
@@ -289,9 +296,10 @@ class MoEGate(nn.Module):
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
self._force_inference_output = _force_inference_output
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# print(bsz, seq_len, h)
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
@@ -309,7 +317,7 @@ class MoEGate(nn.Module):
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
if self.training and self.alpha > 0.0 and not self._force_inference_output:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
@@ -341,14 +349,19 @@ class MOEFeedForwardSwiGLU(nn.Module):
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
_force_inference_output: bool = False,
):
super().__init__()
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
self.experts = nn.ModuleList(
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
)
self._force_inference_output = _force_inference_output
self.gate = MoEGate(
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
embed_dim=dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
self.num_activated_experts = num_activated_experts
@@ -359,7 +372,7 @@ class MOEFeedForwardSwiGLU(nn.Module):
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
if self.training and not self._force_inference_output:
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
@@ -413,6 +426,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
_force_inference_output: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -436,6 +450,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
@@ -480,6 +495,7 @@ class HiDreamImageTransformerBlock(nn.Module):
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
_force_inference_output: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -504,6 +520,7 @@ class HiDreamImageTransformerBlock(nn.Module):
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
@@ -606,6 +623,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
force_inference_output: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -629,6 +647,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=force_inference_output,
)
)
for _ in range(num_layers)
@@ -644,6 +663,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=force_inference_output,
)
)
for _ in range(num_single_layers)
@@ -662,7 +682,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
if is_training and not self.config.force_inference_output:
B, S, F = x.shape
C = F // (self.config.patch_size * self.config.patch_size)
x = (
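The `force_inference_output` config flag threaded through the blocks above keeps the MoE gate and `unpatchify` on their inference code paths even in training mode. A hedged sketch of enabling it when loading the transformer (whether the training script sets the flag exactly this way is an assumption):
```python
import torch
from diffusers import HiDreamImageTransformer2DModel

# Config kwargs passed to from_pretrained override the stored config values,
# so the MoE blocks and unpatchify keep their inference behavior during training.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    force_inference_output=True,
)
```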


@@ -13,6 +13,7 @@ from transformers import (
)
from ...image_processor import VaeImageProcessor
from ...loaders import HiDreamImageLoraLoaderMixin
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
@@ -142,7 +143,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class HiDreamImagePipeline(DiffusionPipeline):
class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]