1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
diffusers/tests/quantization/test_pipeline_level_quantization.py
Sayak Paul 599c887164 feat: pipeline-level quantization config (#11130)
* feat: pipeline-level quant config.

Co-authored-by: SunMarc <marc.sun@hotmail.fr>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (#10756)

* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

[Single File] Add single file loading for SANA Transformer (#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>

[LoRA] CogView4 (#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings  (#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.

Fix SD3 IPAdapter feature extractor (#11027)

chore: fix help messages in advanced diffusion examples (#10923)

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (#11030)

[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>

making ```formatted_images``` initialization compact (#10801)

compact writing

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

[Tests] restrict memory tests for quanto for certain schemes. (#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Fix deterministic issue when getting pipeline dtype and device (#10696)

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[Tests] add requires peft decorator. (#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

CogView4 Control Block (#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

[CI] pin transformers version for benchmarking. (#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

LTX 0.9.5 (#10968)

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

make PR GPU tests conditioned on styling. (#11099)

Group offloading improvements (#11094)

update

Fix pipeline_flux_controlnet.py (#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (#11096)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (#11097)

* update

* update

Quality options in `export_to_video` (#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation

* Revert "feat: pipeline-level quant config."

This reverts commit 316ff46b76.

* feat: implement pipeline-level quantization config

Co-authored-by: SunMarc <marc@huggingface.co>

* update

* fixes

* fix validation.

* add tests and other improvements.

* add tests

* import quality

* remove prints.

* add docs.

* fixes to docs.

* doc fixes.

* doc fixes.

* add validation to the input quantization_config.

* clarify recommendations.

* docs

* add to ci.

* todo.

---------

Co-authored-by: SunMarc <marc@huggingface.co>
2025-05-09 10:04:44 +05:30

191 lines
7.8 KiB
Python

# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
is_transformers_available,
require_accelerate,
require_bitsandbytes_version_greater,
require_quanto,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
if is_transformers_available():
from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
TranBitsAndBytesConfig = None
@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
def test_quant_config_set_correctly_through_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=components_to_quantize,
)
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_quant_config_set_correctly_through_granular(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_raises_error_for_invalid_config(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
},
quant_backend="bitsandbytes_4bit",
)
self.assertTrue(
str(err_context.exception)
== "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
)
def test_validation_for_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_backend="quanto",
quant_kwargs={"weights_dtype": "int8"},
components_to_quantize=components_to_quantize,
)
self.assertTrue(
"The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
)
def test_raises_error_for_wrong_config_class(self):
quant_config = {
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
with self.assertRaises(ValueError) as err_context:
_ = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
self.assertTrue(
str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
)
def test_validation_for_mapping(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": DiffusionPipeline(),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))
def test_saving_loading(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
for name, component in loaded_pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
self.assertTrue(torch.allclose(output_1, output_2))