mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] add LoRA support to HiDream and fine-tuning script (#11281)
* initial commit
* initial commit
* initial commit
* initial commit
* initial commit
* initial commit
* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
* move prompt embeds, pooled embeds outside
* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: hlky <hlky@hlky.ac>
* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: hlky <hlky@hlky.ac>
* fix import
* fix import and tokenizer 4, text encoder 4 loading
* te
* prompt embeds
* fix naming
* shapes
* initial commit to add HiDreamImageLoraLoaderMixin
* fix init
* add tests
* loader
* fix model input
* add code example to readme
* fix default max length of text encoders
* prints
* nullify training cond in unpatchify for temp fix to incompatible shaping of transformer output during training
* smol fix
* unpatchify
* unpatchify
* fix validation
* flip pred and loss
* fix shift!!!
* revert unpatchify changes (for now)
* smol fix
* Apply style fixes
* workaround moe training
* workaround moe training
* remove prints
* to reduce some memory, keep vae in `weight_dtype` same as we have for flux (as it's the same vae)
bbd0c161b5/examples/dreambooth/train_dreambooth_lora_flux.py (L1207)
* refactor to align with HiDream refactor
* refactor to align with HiDream refactor
* refactor to align with HiDream refactor
* add support for cpu offloading of text encoders
* Apply style fixes
* adjust lr and rank for train example
* fix copies
* Apply style fixes
* update README
* update README
* update README
* fix license
* keep prompt2,3,4 as None in validation
* remove reverse ode comment
* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* vae offload change
* fix text encoder offloading
* Apply style fixes
* cleaner to_kwargs
* fix module name in copied from
* add requirements
* fix offloading
* fix offloading
* fix offloading
* update transformers version in reqs
* try AutoTokenizer
* try AutoTokenizer
* Apply style fixes
* empty commit
* Delete tests/lora/test_lora_layers_hidream.py
* change tokenizer_4 to load with AutoTokenizer as well
* make text_encoder_four and tokenizer_four configurable
* save model card
* save model card
* revert T5
* fix test
* remove non diffusers lumina2 conversion
---------
Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
<Tip>
|
||||
@@ -91,6 +92,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
## HiDreamImageLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
133
examples/dreambooth/README_hidream.md
Normal file
133
examples/dreambooth/README_hidream.md
Normal file
@@ -0,0 +1,133 @@
|
||||
# DreamBooth training example for HiDream Image
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sana.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
> [!NOTE]
|
||||
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
|
||||
> 8-bit Adam optimizer, latent caching, offloading, no validation.
|
||||
> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
|
||||
> text embeddings are pre-computed to save memory.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-hidream-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_hidream.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--rank=16 \
|
||||
--learning_rate=2e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--cache_latents \
|
||||
--gradient_checkpointing \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16.
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
|
||||
8
examples/dreambooth/requirements_hidream.txt
Normal file
8
examples/dreambooth/requirements_hidream.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
accelerate>=1.4.0
|
||||
torchvision
|
||||
transformers>=4.50.0
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.14.0
|
||||
sentencepiece
|
||||
220
examples/dreambooth/test_dreambooth_lora_hidream.py
Normal file
220
examples/dreambooth/test_dreambooth_lora_hidream.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
|
||||
text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
|
||||
transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_hidream(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
1721
examples/dreambooth/train_dreambooth_lora_hidream.py
Normal file
1721
examples/dreambooth/train_dreambooth_lora_hidream.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -77,6 +77,7 @@ if is_torch_available():
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -108,6 +109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
|
||||
@@ -5360,6 +5360,325 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`HiDreamImageTransformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -56,6 +56,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -275,7 +275,14 @@ class HiDreamAttnProcessor:
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_routed_experts=4,
|
||||
num_activated_experts=2,
|
||||
aux_loss_alpha=0.01,
|
||||
_force_inference_output=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
@@ -289,9 +296,10 @@ class MoEGate(nn.Module):
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
|
||||
|
||||
self._force_inference_output = _force_inference_output
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
# print(bsz, seq_len, h)
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
@@ -309,7 +317,7 @@ class MoEGate(nn.Module):
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
### expert-level computation auxiliary loss
|
||||
if self.training and self.alpha > 0.0:
|
||||
if self.training and self.alpha > 0.0 and not self._force_inference_output:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
# always compute aux loss based on the naive greedy topk method
|
||||
@@ -341,14 +349,19 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
|
||||
self.experts = nn.ModuleList(
|
||||
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
|
||||
)
|
||||
self._force_inference_output = _force_inference_output
|
||||
self.gate = MoEGate(
|
||||
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
|
||||
embed_dim=dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
@@ -359,7 +372,7 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
if self.training and not self._force_inference_output:
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
@@ -413,6 +426,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -436,6 +450,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
@@ -480,6 +495,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -504,6 +520,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
@@ -606,6 +623,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -629,6 +647,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=force_inference_output,
|
||||
)
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
@@ -644,6 +663,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=force_inference_output,
|
||||
)
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
@@ -662,7 +682,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
|
||||
if is_training:
|
||||
if is_training and not self.config.force_inference_output:
|
||||
B, S, F = x.shape
|
||||
C = F // (self.config.patch_size * self.config.patch_size)
|
||||
x = (
|
||||
|
||||
@@ -13,6 +13,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import HiDreamImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
@@ -142,7 +143,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HiDreamImagePipeline(DiffusionPipeline):
|
||||
class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user