mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* feat: pipeline-level quant config.
Co-authored-by: SunMarc <marc.sun@hotmail.fr>
condition better.
support mapping.
improvements.
[Quantization] Add Quanto backend (#10756)
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* Update docs/source/en/quantization/quanto.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* Update src/diffusers/quantizers/quanto/utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update
* update
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
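For context, a minimal sketch of what the Quanto backend enables, quantizing a model's weights at load time; the FLUX.1-dev checkpoint id is illustrative:

import torch

from diffusers import FluxTransformer2DModel, QuantoConfig

# Quantize the transformer weights to int8 via optimum-quanto.
quant_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)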
[Single File] Add single file loading for SANA Transformer (#10947)
* added support for from_single_file
* added diffusers mapping script
* added testcase
* bug fix
* updated tests
* corrected code quality
* corrected code quality
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
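A hedged sketch of the new single-file path for SANA; the checkpoint URL below is a placeholder, not a verified file:

from diffusers import SanaTransformer2DModel

# Load the SANA transformer directly from an original (non-diffusers) checkpoint.
transformer = SanaTransformer2DModel.from_single_file(
    "https://huggingface.co/some-org/sana-checkpoint/blob/main/model.pth"  # placeholder URL
)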
[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)
* updates
* updates
* updates
* updates
* notebooks revert
* fix-copies.
* seeing
* fix
* revert
* fixes
* fixes
* fixes
* remove print
* fix
* conflicts ii.
* updates
* fixes
* better filtering of prefix.
---------
Co-authored-by: hlky <hlky@hlky.ac>
[LoRA] CogView4 (#10981)
* update
* make fix-copies
* update
[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)
* memory usage tests
* fixes
* gguf
[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)
* Add initial template
* Second template
* feat: Add TextEmbeddingModule to AnyTextPipeline
* feat: Add AuxiliaryLatentModule template to AnyTextPipeline
* Add bert tokenizer from the anytext repo for now
* feat: Update AnyTextPipeline's modify_prompt method
This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.
* Fill in the `forward` pass of `AuxiliaryLatentModule`
* `make style && make quality`
* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`
* Update error handling to raise and logging
* Add `create_glyph_lines` function into `TextEmbeddingModule`
* make style
* Up
* Up
* Up
* Up
* Remove several comments
* refactor: Remove ControlNetConditioningEmbedding and update code accordingly
* Up
* Up
* up
* refactor: Update AnyTextPipeline to include new optional parameters
* up
* feat: Add OCR model and its components
* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies
* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task
* `make style`
* refactor: Update `AnyTextPipeline`'s docstring
* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once
* simplify
* `make style`
* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function
* Simplify for now
* `make style`
* Up
* feat: Add scripts to convert AnyText controlnet to diffusers
* `make style`
* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`
* make style
* Up
* Simplify
* Up
* feat: Add safetensors module for loading model file
* Fix device issues
* Up
* Up
* refactor: Simplify
* refactor: Simplify code for loading models and handling data types
* `make style`
* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule
* refactor: Update dtype in embedding_manager.py to match proj.weight
* Up
* Add attribution and adaptation information to pipeline_anytext.py
* Update usage example
* Will refactor `controlnet_cond_embedding` initialization
* Add `AnyTextControlNetConditioningEmbedding` template
* Refactor organization
* style
* style
* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`
* Follow one-file policy
* style
* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel
* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py
* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py
* Refactor AnyTextControlNet to use configurable conditioning embedding channels
* Complete control net conditioning embedding in AnyTextControlNetModel
* up
* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel
* up
* up
* style
* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline
* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity
* down
* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing
* update pillow
* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity
* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file
* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency
* 🆙
* style
* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py
* style
* Update examples/research_projects/anytext/README.md
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
* Remove commented-out image preparation code in AnyTextPipeline
* Remove unnecessary blank line in README.md
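To make the modify_prompt description above concrete, here is a toy sketch of the placeholder-substitution idea (not the actual AnyText implementation): quoted substrings are pulled out of the prompt and each is replaced with a placeholder glyph token so it can be rendered separately.

import re

PLACEHOLDER = "*"

def modify_prompt(prompt: str):
    # Extract the quoted strings that should be rendered as glyphs ...
    texts = re.findall(r'"([^"]*)"', prompt)
    # ... and swap each occurrence for the placeholder token.
    for text in texts:
        prompt = prompt.replace(f'"{text}"', f" {PLACEHOLDER} ", 1)
    return prompt, texts

print(modify_prompt('a storefront sign that says "OPEN" and "24H"'))
# ('a storefront sign that says  *  and  * ', ['OPEN', '24H'])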
[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018)
* update
* update
* update
* update
* update
* update
* update
* update
* update
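A hedged sketch of the round-trip this PR unlocks, assuming diffusers' TorchAoConfig: torchao tensor subclasses are serialized as pickled objects, which torch>=2.6 can deserialize safely.

import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig

quant_config = TorchAoConfig("int8wo")  # int8 weight-only quantization
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# torchao tensors are not plain tensors, so safetensors cannot be used here.
transformer.save_pretrained("flux-int8wo", safe_serialization=False)
reloaded = FluxTransformer2DModel.from_pretrained("flux-int8wo", torch_dtype=torch.bfloat16)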
fix: mixture tiling sdxl pipeline - adjust generating time_ids & embeddings (#11012)
small fix on generating time_ids & embeddings
[LoRA] support wan i2v loras from the world. (#11025)
* support wan i2v loras from the world.
* remove copied from.
* updates
* add lora.
Fix SD3 IPAdapter feature extractor (#11027)
chore: fix help messages in advanced diffusion examples (#10923)
Fix missing **kwargs in lora_pipeline.py (#11011)
* Update lora_pipeline.py
* Apply style fixes
* fix-copies
---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Fix for multi-GPU WAN inference (#10997)
Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs
Co-authored-by: Jimmy <39@🇺🇸.com>
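The gist of the fix, as a generic sketch (variable names are illustrative, not the exact Wan transformer code): with a device_map-sharded model, the modulation tensors may live on a different GPU than the hidden states, so they are moved over before the elementwise math.

import torch

def apply_scale_shift(hidden_states: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Align devices first; under multi-GPU sharding these can differ.
    shift = shift.to(hidden_states.device)
    scale = scale.to(hidden_states.device)
    return hidden_states * (1 + scale) + shift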
[Refactor] Clean up import utils boilerplate (#11026)
* update
* update
* update
Use `output_size` in `repeat_interleave` (#11030)
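Passing `output_size` lets PyTorch skip reading the `repeats` tensor back to the host to size the result, avoiding a device sync. A toy example:

import torch

x = torch.tensor([1, 2, 3])
repeats = torch.tensor([2, 2, 2])
# output_size is the known sum of repeats; avoids a host-device sync on GPU.
out = torch.repeat_interleave(x, repeats, dim=0, output_size=6)
# tensor([1, 1, 2, 2, 3, 3])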
[hybrid inference 🍯🐝] Add VAE encode (#11017)
* [hybrid inference 🍯🐝] Add VAE encode
* _toctree: add vae encode
* Add endpoints, tests
* vae_encode docs
* vae encode benchmarks
* api reference
* changelog
* Update docs/source/en/hybrid_inference/overview.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)
* Wan Pipeline scaling fix, type hint warning, multi generator fix
* Apply suggestions from code review
[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)
* move to warning.
* test related changes.
Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)
* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline
---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
making ```formatted_images``` initialization compact (#10801)
compact writing
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)
* get_1d_rotary_pos_embed support npu
* Update src/diffusers/models/embeddings.py
---------
Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
[Tests] restrict memory tests for quanto for certain schemes. (#11052)
* restrict memory tests for quanto for certain schemes.
* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* fixes
* style
---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
[LoRA] feat: support non-diffusers wan t2v loras. (#11059)
feat: support non-diffusers wan t2v loras.
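A hedged sketch of loading such a LoRA through the standard entry point; the LoRA repo id and weight name are placeholders:

import torch

from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# Non-diffusers (original-format) state dicts are converted on the fly.
pipe.load_lora_weights("some-user/wan-t2v-lora", weight_name="lora.safetensors")  # placeholder repo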
[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)
Fix: dtype mismatch of prompt embeddings in sd3 controlnet training
Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
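The essence of the fix, as a standalone sketch (`weight_dtype` follows the training scripts' naming; the tensors are stand-ins):

import torch

weight_dtype = torch.float16  # dtype the ControlNet weights run in
prompt_embeds = torch.randn(1, 154, 4096)  # stand-in fp32 embeddings
pooled_prompt_embeds = torch.randn(1, 2048)

# Cast the embeddings to the model dtype before the forward pass.
prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)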
reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)
reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
[Tests] add requires peft decorator. (#11037)
* add requires peft decorator.
* install peft conditionally.
* conditional deps.
Co-authored-by: DN6 <dhruv.nair@gmail.com>
---------
Co-authored-by: DN6 <dhruv.nair@gmail.com>
CogView4 Control Block (#10809)
* cogview4 control training
---------
Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
[CI] pin transformers version for benchmarking. (#11067)
pin transformers version for benchmarking.
updates
Fix Wan I2V Quality (#11087)
* fix_wan_i2v_quality
* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update pipeline_wan_i2v.py
---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
LTX 0.9.5 (#10968)
* update
---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
make PR GPU tests conditioned on styling. (#11099)
Group offloading improvements (#11094)
update
Fix pipeline_flux_controlnet.py (#11095)
* Fix pipeline_flux_controlnet.py
* Fix style
update readme instructions. (#11096)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)
Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP
Fix Group offloading behaviour when using streams (#11097)
* update
* update
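A hedged sketch of the stream-based group offloading these two PRs touch, assuming the apply_group_offloading hook and an illustrative pipeline:

import torch

from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # illustrative checkpoint
)
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,  # prefetch weights on a separate CUDA stream
)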
Quality options in `export_to_video` (#11090)
* Quality options in `export_to_video`
* make style
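A hedged sketch of the new options, assuming the `quality` kwarg added here (imageio's 0-10 scale) and dummy frames:

import numpy as np

from diffusers.utils import export_to_video

frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(16)]  # dummy frames
export_to_video(frames, "output.mp4", fps=16, quality=8.0)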
improve more.
add placeholders for docstrings.
formatting.
smol fix.
solidify validation and annotation
* Revert "feat: pipeline-level quant config."
This reverts commit 316ff46b76.
* feat: implement pipeline-level quantization config
Co-authored-by: SunMarc <marc@huggingface.co>
* update
* fixes
* fix validation.
* add tests and other improvements.
* add tests
* import quality
* remove prints.
* add docs.
* fixes to docs.
* doc fixes.
* doc fixes.
* add validation to the input quantization_config.
* clarify recommendations.
* docs
* add to ci.
* todo.
---------
Co-authored-by: SunMarc <marc@huggingface.co>
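The test file below exercises this feature; for orientation, the kwargs-based flavor it checks looks like this (same tiny checkpoint as the tests):

import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)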
191 lines
7.8 KiB
Python
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
    require_quanto,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)


if is_transformers_available():
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
    TranBitsAndBytesConfig = None


@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"
    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

    def test_quant_config_set_correctly_through_kwargs(self):
        components_to_quantize = ["transformer", "text_encoder_2"]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=components_to_quantize,
        )
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)
        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config
                self.assertTrue(quantization_config.load_in_4bit)
                self.assertTrue(quantization_config.quant_method == "bitsandbytes")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_quant_config_set_correctly_through_granular(self):
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)
        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config

                if name == "text_encoder_2":
                    self.assertTrue(quantization_config.load_in_4bit)
                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                else:
                    self.assertTrue(quantization_config.quant_method == "quanto")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_raises_error_for_invalid_config(self):
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig(weights_dtype="int8"),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
                },
                quant_backend="bitsandbytes_4bit",
            )

        self.assertTrue(
            str(err_context.exception)
            == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
        )

    def test_validation_for_kwargs(self):
        components_to_quantize = ["transformer", "text_encoder_2"]
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_backend="quanto",
                quant_kwargs={"weights_dtype": "int8"},
                components_to_quantize=components_to_quantize,
            )

        self.assertTrue(
            "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
        )

    def test_raises_error_for_wrong_config_class(self):
        quant_config = {
            "transformer": QuantoConfig(weights_dtype="int8"),
            "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
        }
        with self.assertRaises(ValueError) as err_context:
            _ = DiffusionPipeline.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
            )
        self.assertTrue(
            str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
        )

    def test_validation_for_mapping(self):
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": DiffusionPipeline(),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
                }
            )

        self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))

    def test_saving_loading(self):
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
        output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
        for name, component in loaded_pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config

                if name == "text_encoder_2":
                    self.assertTrue(quantization_config.load_in_4bit)
                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                else:
                    self.assertTrue(quantization_config.quant_method == "quanto")

        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))