1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into training-group-offloading-tests

This commit is contained in:
Sayak Paul
2025-05-12 12:58:40 +05:30
committed by GitHub
26 changed files with 907 additions and 148 deletions

View File

@@ -142,6 +142,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
@@ -525,6 +526,60 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_pipeline_level_quantization_tests:
name: Torch quantization nightly tests
strategy:
fail-fast: false
max-parallel: 2
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U bitsandbytes optimum_quanto
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
- name: Pipeline-level quantization tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_pipeline_level_quant_torch_cuda \
--report-log=tests_pipeline_level_quant_torch_cuda.log \
tests/quantization/test_pipeline_level_quantization.py
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
# run_nightly_tests_apple_m1:

View File

@@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License.
# Quantization
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.
<Tip>
@@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
</Tip>
## PipelineQuantizationConfig
[[autodoc]] quantizers.PipelineQuantizationConfig
## BitsAndBytesConfig

View File

@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
- [Quanto](./quanto.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
## Pipeline-level quantization
Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
do this with [`~quantizers.PipelineQuantizationConfig`].
Start by defining a `PipelineQuantizationConfig`:
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": BitsAndBytesConfig(
load_in_4bit=True, compute_dtype=torch.bfloat16
),
}
)
```
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
```py
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe("photo of a cute dog").images[0]
```
This method allows for more granular control over the quantization specifications of individual
model-level components of a pipeline. It also allows for different quantization backends for
different components. In the above example, you used a combination of Quanto and BitsandBytes. However,
one caveat of this method is that users need to know which components come from `transformers` to be able
to import the right quantization config class.
The other method is simpler in terms of experience but is
less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way:
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
```
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.
In this case, `quant_kwargs` will be used to initialize the quantization specifications
of the respective quantization configuration class of `quant_backend`. `components_to_quantize`
is used to denote the components that will be quantized. For most pipelines, you would want to
keep `transformer` in the list as that is often the most compute and memory intensive.
The config below will work for most diffusion pipelines that have a `transformer` component present.
In most case, you will want to quantize the `transformer` component as that is often the most compute-
intensive part of a diffusion pipeline.
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer"],
)
```
Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
* `bitsandbytes_4bit`
* `bitsandbytes_8bit`
* `gguf`
* `quanto`
* `torchao`
Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).

View File

@@ -1704,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
raise ValueError("Invalid LoRA state dict for HiDream.")
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict

View File

@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights

View File

@@ -152,9 +152,19 @@ class HunyuanVideoFramepackTransformer3DModel(
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
# Framepack history projection embedder
self.clean_x_embedder = None
if has_clean_x_embedder:
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
# Framepack image-conditioning embedder
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
@@ -186,14 +196,7 @@ class HunyuanVideoFramepackTransformer3DModel(
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
# Framepack specific modules
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.clean_x_embedder = None
if has_clean_x_embedder:
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.use_gradient_checkpointing = False
self.gradient_checkpointing = False
def forward(
self,

View File

@@ -789,6 +789,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)

View File

@@ -675,8 +675,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
# retrieve class candidates
@@ -769,6 +771,17 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
and issubclass(class_obj, torch.nn.Module)
):
model_quant_config = quantization_config._resolve_quant_config(
is_diffusers=is_diffusers_model, module_name=name
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config
# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries

View File

@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
@@ -725,6 +726,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -741,6 +743,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
" install accelerate\n```\n."
)
if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -1001,6 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

View File

@@ -12,5 +12,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Union
from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:
class TransformersQuantConfigMixin:
pass
logger = logging.get_logger(__name__)
class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
and `components_to_quantize`.
"""
def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.post_init()
def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = True if quant_mapping is not None else False
self._validate_init_args()
def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()
if self.quant_mapping is not None:
self._validate_quant_mapping_args()
def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend
self._check_backend_availability(quant_backend)
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
"this mapping would look like."
)
def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()
available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())
for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue
if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)
def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize
# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
return config
# Global config case
else:
should_quantize = False
# Only quantize the modules requested for.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True
if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
return quant_config_cls(**quant_kwargs)
# Fallback: no applicable configuration found.
return None
def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
return quant_config_mapping_transformers, quant_config_mapping_diffusers

View File

@@ -75,7 +75,7 @@ class QuantizationConfigMixin:
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):

View File

@@ -38,6 +38,7 @@ from .import_utils import (
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
is_peft_available,
is_timm_available,
is_torch_available,
@@ -486,6 +487,13 @@ def require_bitsandbytes(test_case):
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.

View File

@@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
floats_tensor,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
@@ -1778,7 +1777,7 @@ class TorchCompileTesterMixin:
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
class TestLoraHotSwappingForModel(unittest.TestCase):
class LoraHotSwappingForModelTesterMixin:
"""Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1799,48 +1798,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
gc.collect()
backend_empty_cache(torch_device)
def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 8,
"attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
"sample_size": 16,
}
model = UNet2DConditionModel(**init_dict)
return model.to(torch_device)
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
def get_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
from peft import LoraConfig
unet_lora_config = LoraConfig(
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
return lora_config
def get_dummy_input(self):
# from UNet2DConditionModelTests
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
@@ -1858,23 +1833,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
fine.
"""
# create 2 adapters with different ranks and alphas
dummy_input = self.get_dummy_input()
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
unet = self.get_small_unet()
unet.add_adapter(lora_config0, adapter_name="adapter0")
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
output0_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
unet.add_adapter(lora_config1, adapter_name="adapter1")
unet.set_adapter("adapter1")
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
output1_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
tol = 5e-3
@@ -1884,40 +1863,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del unet
model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
unet = self.get_small_unet()
torch.manual_seed(0)
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
unet.enable_lora_hotswap(target_rank=max_rank)
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
unet = torch.compile(unet, mode="reduce-overhead")
model = torch.compile(model, mode="reduce-overhead")
with torch.inference_mode():
output0_after = unet(**dummy_input)["sample"]
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = unet(**dummy_input)["sample"]
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with self.assertRaisesRegex(ValueError, msg):
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_model(self, rank0, rank1):
@@ -1934,6 +1916,9 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1941,52 +1926,77 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
target_modules.append(self.get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
unet.enable_lora_hotswap(target_rank=32)
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
from diffusers.loaders.peft import logger
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self):
# check the error and log

View File

@@ -22,7 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -78,7 +78,9 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class FluxTransformerTests(
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.

View File

@@ -0,0 +1,116 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoFramepackTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.5, 0.7, 0.9]
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 3
height = 4
width = 4
text_encoder_embedding_dim = 16
image_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
indices_latents = torch.ones((3,)).to(torch_device)
latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
torch_device
)
indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
"image_embeds": image_embeds,
"indices_latents": indices_latents,
"latents_clean": latents_clean,
"indices_latents_clean": indices_latents_clean,
"latents_history_2x": latents_history_2x,
"indices_latents_history_2x": indices_latents_history_2x,
"latents_history_4x": latents_history_4x,
"indices_latents_history_4x": indices_latents_history_4x,
}
@property
def input_shape(self):
return (4, 3, 4, 4)
@property
def output_shape(self):
return (4, 3, 4, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 2,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
"has_image_proj": True,
"image_proj_dim": 16,
"has_clean_x_embedder": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class UNet2DConditionModelTests(
ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.

View File

@@ -24,9 +24,10 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -316,19 +317,19 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow
@require_torch_gpu
@require_torch_accelerator
class ConsisIDPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_consisid(self):
generator = torch.Generator("cpu").manual_seed(0)
@@ -338,8 +339,8 @@ class ConsisIDPipelineIntegrationTests(unittest.TestCase):
prompt = self.prompt
image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
id_vit_hidden = [torch.ones([1, 2, 2])] * 1
id_cond = torch.ones(1, 2)
id_vit_hidden = [torch.ones([1, 577, 1024])] * 5
id_cond = torch.ones(1, 1280)
videos = pipe(
image=image,
@@ -357,5 +358,5 @@ class ConsisIDPipelineIntegrationTests(unittest.TestCase):
video = videos[0]
expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
max_diff = numpy_cosine_similarity_distance(video, expected_video)
max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
assert max_diff < 1e-3, f"Max diff is too high. got {video}"

View File

@@ -21,7 +21,15 @@ import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_torch_accelerator,
torch_device,
)
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
@@ -107,23 +115,23 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_dit_256(self):
generator = torch.manual_seed(0)
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella", "white shark", "white wolf"]
ids = pipe.get_label_ids(words)
@@ -139,7 +147,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
def test_dit_512(self):
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella"]
ids = pipe.get_label_ids(words)
@@ -152,4 +160,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
expected_slice = expected_image.flatten()
output_slice = image.flatten()
assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2

View File

@@ -27,9 +27,10 @@ from diffusers import (
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -256,19 +257,19 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow
@require_torch_gpu
@require_torch_accelerator
class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_EasyAnimate(self):
generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -27,8 +27,8 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_torch_gpu,
require_big_accelerator,
require_torch_accelerator,
torch_device,
)
@@ -266,9 +266,9 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
@nightly
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_torch_accelerator
@require_big_accelerator
@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
@@ -302,5 +302,5 @@ class MochiPipelineIntegrationTests(unittest.TestCase):
video = videos[0]
expected_video = torch.randn(1, 19, 480, 848, 3).numpy()
max_diff = numpy_cosine_similarity_distance(video, expected_video)
max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
assert max_diff < 1e-3, f"Max diff is too high. got {video}"

View File

@@ -7,8 +7,10 @@ from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -87,7 +89,7 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_torch_gpu
@require_torch_accelerator
class OmniGenPipelineSlowTests(unittest.TestCase):
pipeline_class = OmniGenPipeline
repo_id = "shitao/OmniGen-v1-diffusers"
@@ -95,12 +97,12 @@ class OmniGenPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -125,21 +127,56 @@ class OmniGenPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
expected_slices = Expectations(
{
("xpu", 3): np.array(
[
[0.05859375, 0.05859375, 0.04492188],
[0.04882812, 0.04101562, 0.03320312],
[0.04882812, 0.04296875, 0.03125],
[0.04296875, 0.0390625, 0.03320312],
[0.04296875, 0.03710938, 0.03125],
[0.04492188, 0.0390625, 0.03320312],
[0.04296875, 0.03710938, 0.03125],
[0.04101562, 0.03710938, 0.02734375],
[0.04101562, 0.03515625, 0.02734375],
[0.04101562, 0.03515625, 0.02929688],
],
dtype=np.float32,
),
("cuda", 7): np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
),
("cuda", 8): np.array(
[
[0.0546875, 0.05664062, 0.04296875],
[0.046875, 0.04101562, 0.03320312],
[0.05078125, 0.04296875, 0.03125],
[0.04296875, 0.04101562, 0.03320312],
[0.0390625, 0.03710938, 0.02929688],
[0.04296875, 0.03710938, 0.03125],
[0.0390625, 0.03710938, 0.02929688],
[0.0390625, 0.03710938, 0.02734375],
[0.0390625, 0.03320312, 0.02734375],
[0.0390625, 0.03320312, 0.02734375],
],
dtype=np.float32,
),
}
)
expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

View File

@@ -25,11 +25,12 @@ from transformers import CLIPImageProcessor, CLIPVisionConfig
from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
nightly,
require_torch_gpu,
require_torch_accelerator,
torch_device,
)
@@ -174,19 +175,19 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_paint_by_example(self):
# make sure here that pndm scheduler skips prk

View File

@@ -32,7 +32,14 @@ from diffusers import (
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
nightly,
require_torch_accelerator,
torch_device,
)
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -419,17 +426,17 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
class StableAudioPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -459,9 +466,15 @@ class StableAudioPipelineIntegrationTests(unittest.TestCase):
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
audio_slice = audio[0, 447590:447600]
# fmt: off
expected_slice = np.array(
[-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
expected_slices = Expectations(
{
("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]),
("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]),
("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]),
}
)
# fmt: one
# fmt: on
expected_slice = expected_slices.get_expectation()
max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
assert max_diff < 1.5e-3

View File

@@ -389,7 +389,7 @@ class BnB4BitBasicTests(Base4bitTests):
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -657,7 +657,7 @@ class SlowBnb4BitTests(Base4bitTests):
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -674,7 +674,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
del self.pipeline_4bit
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_quality(self):
# keep the resolution and max tokens to a lower number for faster execution.
@@ -722,7 +722,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
self.pipeline_4bit.enable_model_cpu_offload()
@@ -731,7 +731,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
del self.pipeline_4bit
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

View File

@@ -0,0 +1,190 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
is_transformers_available,
require_accelerate,
require_bitsandbytes_version_greater,
require_quanto,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
if is_transformers_available():
from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
TranBitsAndBytesConfig = None
@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
def test_quant_config_set_correctly_through_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=components_to_quantize,
)
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_quant_config_set_correctly_through_granular(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for name, component in pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
_ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
def test_raises_error_for_invalid_config(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
},
quant_backend="bitsandbytes_4bit",
)
self.assertTrue(
str(err_context.exception)
== "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
)
def test_validation_for_kwargs(self):
components_to_quantize = ["transformer", "text_encoder_2"]
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_backend="quanto",
quant_kwargs={"weights_dtype": "int8"},
components_to_quantize=components_to_quantize,
)
self.assertTrue(
"The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
)
def test_raises_error_for_wrong_config_class(self):
quant_config = {
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
with self.assertRaises(ValueError) as err_context:
_ = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
self.assertTrue(
str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
)
def test_validation_for_mapping(self):
with self.assertRaises(ValueError) as err_context:
_ = PipelineQuantizationConfig(
quant_mapping={
"transformer": DiffusionPipeline(),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))
def test_saving_loading(self):
quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
}
)
components_to_quantize = list(quant_config.quant_mapping.keys())
pipe = DiffusionPipeline.from_pretrained(
self.model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
for name, component in loaded_pipe.components.items():
if name in components_to_quantize:
self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
quantization_config = component.config.quantization_config
if name == "text_encoder_2":
self.assertTrue(quantization_config.load_in_4bit)
self.assertTrue(quantization_config.quant_method == "bitsandbytes")
else:
self.assertTrue(quantization_config.quant_method == "quanto")
output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
self.assertTrue(torch.allclose(output_1, output_2))

View File

@@ -34,13 +34,24 @@ try:
print("Torch version:", torch.__version__)
print("Cuda available:", torch.cuda.is_available())
print("Cuda version:", torch.version.cuda)
print("CuDNN version:", torch.backends.cudnn.version())
print("Number of GPUs available:", torch.cuda.device_count())
if torch.cuda.is_available():
print("Cuda version:", torch.version.cuda)
print("CuDNN version:", torch.backends.cudnn.version())
print("Number of GPUs available:", torch.cuda.device_count())
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
print(f"CUDA memory: {total_memory} GB")
print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
if hasattr(torch, "xpu") and torch.xpu.is_available():
print("XPU model:", torch.xpu.get_device_properties(0).name)
print("XPU compiler version:", torch.version.xpu)
print("Number of XPUs available:", torch.xpu.device_count())
device_properties = torch.xpu.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
print(f"XPU memory: {total_memory} GB")
except ImportError:
print("Torch version:", None)