diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 7696852ecd..4f92717df8 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -142,6 +142,7 @@ jobs: HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms CUBLAS_WORKSPACE_CONFIG: :16:8 + RUN_COMPILE: yes run: | python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ @@ -525,6 +526,60 @@ jobs: pip install slack_sdk tabulate python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + run_nightly_pipeline_level_quantization_tests: + name: Torch quantization nightly tests + strategy: + fail-fast: false + max-parallel: 2 + runs-on: + group: aws-g6e-xlarge-plus + container: + image: diffusers/diffusers-pytorch-cuda + options: --shm-size "20gb" --ipc host --gpus 0 + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: NVIDIA-SMI + run: nvidia-smi + - name: Install dependencies + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install -e [quality,test] + python -m uv pip install -U bitsandbytes optimum_quanto + python -m uv pip install pytest-reportlog + - name: Environment + run: | + python utils/print_env.py + - name: Pipeline-level quantization tests on GPU + env: + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} + # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms + CUBLAS_WORKSPACE_CONFIG: :16:8 + BIG_GPU_MEMORY: 40 + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + --make-reports=tests_pipeline_level_quant_torch_cuda \ + --report-log=tests_pipeline_level_quant_torch_cuda.log \ + tests/quantization/test_pipeline_level_quantization.py + - name: Failure short reports + if: ${{ failure() }} + run: | + cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt + cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: torch_cuda_pipeline_level_quant_reports + path: reports + - name: Generate Report and Notify Channel + if: always() + run: | + pip install slack_sdk tabulate + python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + # M1 runner currently not well supported # TODO: (Dhruv) add these back when we setup better testing for Apple Silicon # run_nightly_tests_apple_m1: diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2c728cff3c..e2ca990190 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License. # Quantization -Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). - -Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). 
This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.
@@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 
+## PipelineQuantizationConfig
+
+[[autodoc]] quantizers.PipelineQuantizationConfig
 
 ## BitsAndBytesConfig
 
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index 93323f86c7..68b99f524e 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
 - [Quanto](./quanto.md)
 
 [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+
+## Pipeline-level quantization
+
+Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
+quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
+do this with [`~quantizers.PipelineQuantizationConfig`].
+
+Start by defining a `PipelineQuantizationConfig`:
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers.quantization_config import QuantoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from transformers import BitsAndBytesConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_mapping={
+        "transformer": QuantoConfig(weights_dtype="int8"),
+        "text_encoder_2": BitsAndBytesConfig(
+            load_in_4bit=True, compute_dtype=torch.bfloat16
+        ),
+    }
+)
+```
+
+Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
+
+```py
+pipe = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    quantization_config=pipeline_quant_config,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+image = pipe("photo of a cute dog").images[0]
+```
+
+This method allows for granular control over the quantization specifications of individual
+pipeline components. It also allows different quantization backends to be used for different
+components. In the above example, you used a combination of Quanto and bitsandbytes. However,
+one caveat of this method is that users need to know which components come from `transformers` to be able
+to import the right quantization config class.
+
+The other method is simpler to use but less flexible. Start by defining a `PipelineQuantizationConfig`
+in a different way:
+
+```py
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_backend="bitsandbytes_4bit",
+    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+    components_to_quantize=["transformer", "text_encoder_2"],
+)
+```
+
+This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.
+
+In this case, `quant_kwargs` initializes the quantization configuration class of `quant_backend`,
+and `components_to_quantize` denotes the components that will be quantized. For most pipelines,
+you would want to keep `transformer` in the list as that is often the most compute- and
+memory-intensive component.
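+
+To confirm which components actually ended up quantized, you can inspect the loaded pipeline. The snippet
+below is a minimal sketch; it assumes each quantized component records its settings under
+`config.quantization_config`, which is how the `diffusers` and `transformers` model classes store it:
+
+```py
+for name, component in pipe.components.items():
+    # Only torch.nn.Module components can carry a quantization config; schedulers,
+    # tokenizers, and unquantized models simply return None here.
+    config = getattr(component, "config", None)
+    quant_config = getattr(config, "quantization_config", None)
+    if quant_config is not None:
+        print(name, quant_config.quant_method)
+```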
+
+The config below will work for most diffusion pipelines that have a `transformer` component:
+
+```py
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_backend="bitsandbytes_4bit",
+    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+    components_to_quantize=["transformer"],
+)
+```
+
+Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
+
+* `bitsandbytes_4bit`
+* `bitsandbytes_8bit`
+* `gguf`
+* `quanto`
+* `torchao`
+
+
+Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
+recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
+Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
+`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
\ No newline at end of file
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index d5fa7dcfc3..af7547f451 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1704,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for HiDream.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 810ec8adb1..b68cf72b91 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
+    _convert_non_diffusers_hidream_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
                 logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index 58b8115694..c2eb7fd2a7 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py @@ -152,9 +152,19 @@ class HunyuanVideoFramepackTransformer3DModel( # 1. Latent and condition embedders self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + + # Framepack history projection embedder + self.clean_x_embedder = None + if has_clean_x_embedder: + self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) + + # Framepack image-conditioning embedder + self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None + self.time_text_embed = HunyuanVideoConditionEmbedding( inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type ) @@ -186,14 +196,7 @@ class HunyuanVideoFramepackTransformer3DModel( self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) - # Framepack specific modules - self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None - - self.clean_x_embedder = None - if has_clean_x_embedder: - self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim) - - self.use_gradient_checkpointing = False + self.gradient_checkpointing = False def forward( self, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 6f3faed8ff..71ec58cdfd 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -789,6 +789,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9788d758e9..3404ae5130 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -675,8 +675,10 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" + from ..quantizers import PipelineQuantizationConfig # retrieve class candidates @@ -769,6 +771,17 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if ( + quantization_config is not None + and isinstance(quantization_config, PipelineQuantizationConfig) + and issubclass(class_obj, torch.nn.Module) + ): + model_quant_config = quantization_config._resolve_quant_config( + is_diffusers=is_diffusers_model, module_name=name + ) + if model_quant_config is not None: + loading_kwargs["quantization_config"] = model_quant_config + # check if the module is in a subdirectory if dduf_entries: loading_kwargs["dduf_entries"] = dduf_entries diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3be3e46ca4..7cb2a12d3c 100644 --- 
a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
 from ..models import AutoencoderKL
 from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers import PipelineQuantizationConfig
 from ..quantizers.bitsandbytes.utils import _check_bnb_status
 from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (
@@ -725,6 +726,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_safetensors = kwargs.pop("use_safetensors", None)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
@@ -741,6 +743,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 " install accelerate\n```\n."
             )
 
+        if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
+            raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
+
         if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
             raise NotImplementedError(
                 "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -1001,6 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     use_safetensors=use_safetensors,
                     dduf_entries=dduf_entries,
                     provider_options=provider_options,
+                    quantization_config=quantization_config,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
index 4c8483a3d6..bd9e2303c9 100644
--- a/src/diffusers/quantizers/__init__.py
+++ b/src/diffusers/quantizers/__init__.py
@@ -12,5 +12,183 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+from typing import Dict, List, Optional, Union
+
+from ..utils import is_transformers_available, logging
 from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer
+from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+
+try:
+    from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+except ImportError:
+
+    class TransformersQuantConfigMixin:
+        pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class PipelineQuantizationConfig:
+    """
+    Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
+
+    Args:
+        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
+            is available to both `diffusers` and `transformers`.
+        quant_kwargs (`dict`): Keyword arguments used to initialize the quantization backend class.
+        components_to_quantize (`list`): Components of a pipeline to be quantized.
+        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
+            components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`,
+            and `components_to_quantize`.
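+
+    Example (an illustrative sketch of the global `quant_backend` style; the checkpoint name is a
+    placeholder):
+
+    ```py
+    import torch
+    from diffusers import DiffusionPipeline
+    from diffusers.quantizers import PipelineQuantizationConfig
+
+    quant_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+    pipe = DiffusionPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", quantization_config=quant_config, torch_dtype=torch.bfloat16
+    )
+    ```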
+ """ + + def __init__( + self, + quant_backend: str = None, + quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, + components_to_quantize: Optional[List[str]] = None, + quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, + ): + self.quant_backend = quant_backend + # Initialize kwargs to be {} to set to the defaults. + self.quant_kwargs = quant_kwargs or {} + self.components_to_quantize = components_to_quantize + self.quant_mapping = quant_mapping + + self.post_init() + + def post_init(self): + quant_mapping = self.quant_mapping + self.is_granular = True if quant_mapping is not None else False + + self._validate_init_args() + + def _validate_init_args(self): + if self.quant_backend and self.quant_mapping: + raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") + + if not self.quant_mapping and not self.quant_backend: + raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") + + if not self.quant_kwargs and not self.quant_mapping: + raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") + + if self.quant_backend is not None: + self._validate_init_kwargs_in_backends() + + if self.quant_mapping is not None: + self._validate_quant_mapping_args() + + def _validate_init_kwargs_in_backends(self): + quant_backend = self.quant_backend + + self._check_backend_availability(quant_backend) + + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + if quant_config_mapping_transformers is not None: + init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) + init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} + else: + init_kwargs_transformers = None + + init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) + init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} + + if init_kwargs_transformers != init_kwargs_diffusers: + raise ValueError( + "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " + "this mapping would look like." + ) + + def _validate_quant_mapping_args(self): + quant_mapping = self.quant_mapping + transformers_map, diffusers_map = self._get_quant_config_list() + + available_transformers = list(transformers_map.values()) if transformers_map else None + available_diffusers = list(diffusers_map.values()) + + for module_name, config in quant_mapping.items(): + if any(isinstance(config, cfg) for cfg in available_diffusers): + continue + + if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): + continue + + if available_transformers: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}; " + f"Available transformers configs: {available_transformers}." + ) + else: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}." 
+ ) + + def _check_backend_availability(self, quant_backend: str): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_backends_transformers = ( + list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None + ) + available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) + + if ( + available_backends_transformers and quant_backend not in available_backends_transformers + ) or quant_backend not in quant_config_mapping_diffusers: + error_message = f"Provided quant_backend={quant_backend} was not found." + if available_backends_transformers: + error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." + error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." + raise ValueError(error_message) + + def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + quant_mapping = self.quant_mapping + components_to_quantize = self.components_to_quantize + + # Granular case + if self.is_granular and module_name in quant_mapping: + logger.debug(f"Initializing quantization config class for {module_name}.") + config = quant_mapping[module_name] + return config + + # Global config case + else: + should_quantize = False + # Only quantize the modules requested for. + if components_to_quantize and module_name in components_to_quantize: + should_quantize = True + # No specification for `components_to_quantize` means all modules should be quantized. + elif not self.is_granular and not components_to_quantize: + should_quantize = True + + if should_quantize: + logger.debug(f"Initializing quantization config class for {module_name}.") + mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers + quant_config_cls = mapping_to_use[self.quant_backend] + quant_kwargs = self.quant_kwargs + return quant_config_cls(**quant_kwargs) + + # Fallback: no applicable configuration found. + return None + + def _get_quant_config_list(self): + if is_transformers_available(): + from transformers.quantizers.auto import ( + AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, + ) + else: + quant_config_mapping_transformers = None + + from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers + + return quant_config_mapping_transformers, quant_config_mapping_diffusers diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index cc4d4fc1b0..7fec68642f 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -75,7 +75,7 @@ class QuantizationConfigMixin: Args: config_dict (`Dict[str, Any]`): Dictionary that will be used to instantiate the configuration object. - return_unused_kwargs (`bool`,*optional*, defaults to `False`): + return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in `PreTrainedModel`. 
kwargs (`Dict[str, Any]`): diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a524e76f1..00aad9d71a 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -38,6 +38,7 @@ from .import_utils import ( is_note_seq_available, is_onnx_available, is_opencv_available, + is_optimum_quanto_available, is_peft_available, is_timm_available, is_torch_available, @@ -486,6 +487,13 @@ def require_bitsandbytes(test_case): return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) +def require_quanto(test_case): + """ + Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed. + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case) + + def require_accelerate(test_case): """ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 97118c98f4..4602c042d7 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import ( backend_max_memory_allocated, backend_reset_peak_memory_stats, backend_synchronize, - floats_tensor, get_python_version, is_torch_compile, numpy_cosine_similarity_distance, @@ -1778,7 +1777,7 @@ class TorchCompileTesterMixin: @require_peft_backend @require_peft_version_greater("0.14.0") @is_torch_compile -class TestLoraHotSwappingForModel(unittest.TestCase): +class LoraHotSwappingForModelTesterMixin: """Test that hotswapping does not result in recompilation on the model directly. We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively @@ -1799,48 +1798,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase): gc.collect() backend_empty_cache(torch_device) - def get_small_unet(self): - # from diffusers UNet2DConditionModelTests - torch.manual_seed(0) - init_dict = { - "block_out_channels": (4, 8), - "norm_num_groups": 4, - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 8, - "attention_head_dim": 2, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 1, - "sample_size": 16, - } - model = UNet2DConditionModel(**init_dict) - return model.to(torch_device) - - def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules): + def get_lora_config(self, lora_rank, lora_alpha, target_modules): # from diffusers test_models_unet_2d_condition.py from peft import LoraConfig - unet_lora_config = LoraConfig( + lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules, init_lora_weights=False, use_dora=False, ) - return unet_lora_config + return lora_config - def get_dummy_input(self): - # from UNet2DConditionModelTests - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + def get_linear_module_name_other_than_attn(self, model): + linear_names = [ + name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name + ] + return 
linear_names[0] def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None): """ @@ -1858,23 +1833,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase): fine. """ # create 2 adapters with different ranks and alphas - dummy_input = self.get_dummy_input() + torch.manual_seed(0) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) if target_modules1 is None: target_modules1 = target_modules0[:] - lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0) - lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1) + lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1) - unet = self.get_small_unet() - unet.add_adapter(lora_config0, adapter_name="adapter0") + model.add_adapter(lora_config0, adapter_name="adapter0") with torch.inference_mode(): - output0_before = unet(**dummy_input)["sample"] + torch.manual_seed(0) + output0_before = model(**inputs_dict)["sample"] - unet.add_adapter(lora_config1, adapter_name="adapter1") - unet.set_adapter("adapter1") + model.add_adapter(lora_config1, adapter_name="adapter1") + model.set_adapter("adapter1") with torch.inference_mode(): - output1_before = unet(**dummy_input)["sample"] + torch.manual_seed(0) + output1_before = model(**inputs_dict)["sample"] # sanity checks: tol = 5e-3 @@ -1884,40 +1863,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase): with tempfile.TemporaryDirectory() as tmp_dirname: # save the adapter checkpoints - unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") - unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") - del unet + model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") + model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") + del model # load the first adapter - unet = self.get_small_unet() + torch.manual_seed(0) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + if do_compile or (rank0 != rank1): # no need to prepare if the model is not compiled or if the ranks are identical - unet.enable_lora_hotswap(target_rank=max_rank) + model.enable_lora_hotswap(target_rank=max_rank) file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") - unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) + model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: - unet = torch.compile(unet, mode="reduce-overhead") + model = torch.compile(model, mode="reduce-overhead") with torch.inference_mode(): - output0_after = unet(**dummy_input)["sample"] + output0_after = model(**inputs_dict)["sample"] assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter - unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) + model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with 
torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
@@ -1934,6 +1916,9 @@
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1941,52 +1926,77 @@
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from a
+        # non-attention block.
+ target_modules = ["to_q"] + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + target_modules.append(self.get_linear_module_name_other_than_attn(model)) + del model + + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): # ensure that enable_lora_hotswap is called before loading the first adapter - lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) - unet = self.get_small_unet() - unet.add_adapter(lora_config) + lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") with self.assertRaisesRegex(RuntimeError, msg): - unet.enable_lora_hotswap(target_rank=32) + model.enable_lora_hotswap(target_rank=32) def test_enable_lora_hotswap_called_after_adapter_added_warning(self): # ensure that enable_lora_hotswap is called before loading the first adapter from diffusers.loaders.peft import logger - lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) - unet = self.get_small_unet() - unet.add_adapter(lora_config) + lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) msg = ( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." 
) with self.assertLogs(logger=logger, level="WARNING") as cm: - unet.enable_lora_hotswap(target_rank=32, check_compiled="warn") + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") assert any(msg in log for log in cm.output) def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): # check possibility to ignore the error/warning - lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) - unet = self.get_small_unet() - unet.add_adapter(lora_config) + lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Capture all warnings - unet.enable_lora_hotswap(target_rank=32, check_compiled="warn") + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): # check that wrong argument value raises an error - lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) - unet = self.get_small_unet() - unet.add_adapter(lora_config) + lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") with self.assertRaisesRegex(ValueError, msg): - unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") def test_hotswap_second_adapter_targets_more_layers_raises(self): # check the error and log diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index f767d2196e..2a61d66135 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -22,7 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor from diffusers.models.embeddings import ImageProjection from diffusers.utils.testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin enable_full_determinism() @@ -78,7 +78,9 @@ def create_flux_ip_adapter_state_dict(model): return ip_state_dict -class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class FluxTransformerTests( + ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase +): model_class = FluxTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py new file mode 100644 index 0000000000..5f485b210f --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py @@ -0,0 +1,116 @@ +# Copyright 2024 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideoFramepackTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoFramepackTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + model_split_percents = [0.5, 0.7, 0.9] + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 3 + height = 4 + width = 4 + text_encoder_embedding_dim = 16 + image_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device) + indices_latents = torch.ones((3,)).to(torch_device) + latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) + indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device) + latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) + indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device) + latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to( + torch_device + ) + indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + "image_embeds": image_embeds, + "indices_latents": indices_latents, + "latents_clean": latents_clean, + "indices_latents_clean": indices_latents_clean, + "latents_history_2x": latents_history_2x, + "indices_latents_history_2x": indices_latents_history_2x, + "latents_history_4x": latents_history_4x, + "indices_latents_history_4x": indices_latents_history_4x, + } + + @property + def input_shape(self): + return (4, 3, 4, 4) + + @property + def output_shape(self): + return (4, 3, 4, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + 
"num_refiner_layers": 1, + "patch_size": 2, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + "image_condition_type": None, + "has_image_proj": True, + "image_proj_dim": 16, + "has_clean_x_embedder": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoFramepackTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index d01a0b4935..94a5d641a7 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import ( torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin if is_peft_available(): @@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class UNet2DConditionModelTests( + ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase +): model_class = UNet2DConditionModel main_input_name = "sample" # We override the items here because the unet under consideration is small. diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py index a39c17bb4f..7284ecc6f4 100644 --- a/tests/pipelines/consisid/test_consisid.py +++ b/tests/pipelines/consisid/test_consisid.py @@ -24,9 +24,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler from diffusers.utils import load_image from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -316,19 +317,19 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @slow -@require_torch_gpu +@require_torch_accelerator class ConsisIDPipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_consisid(self): generator = torch.Generator("cpu").manual_seed(0) @@ -338,8 +339,8 @@ class ConsisIDPipelineIntegrationTests(unittest.TestCase): prompt = self.prompt image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true") - id_vit_hidden = [torch.ones([1, 2, 2])] * 1 - id_cond = torch.ones(1, 2) + id_vit_hidden = [torch.ones([1, 577, 1024])] * 5 + id_cond = torch.ones(1, 1280) videos = pipe( image=image, @@ -357,5 +358,5 @@ class ConsisIDPipelineIntegrationTests(unittest.TestCase): video = videos[0] expected_video = torch.randn(1, 16, 480, 720, 3).numpy() - max_diff = numpy_cosine_similarity_distance(video, expected_video) + max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video) assert max_diff < 1e-3, f"Max diff is too high. 
got {video}" diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index 18732c0058..65f39db078 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -21,7 +21,15 @@ import torch from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + load_numpy, + nightly, + numpy_cosine_similarity_distance, + require_torch_accelerator, + torch_device, +) from ..pipeline_params import ( CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS, @@ -107,23 +115,23 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class DiTPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_dit_256(self): generator = torch.manual_seed(0) pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256") - pipe.to("cuda") + pipe.to(torch_device) words = ["vase", "umbrella", "white shark", "white wolf"] ids = pipe.get_label_ids(words) @@ -139,7 +147,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase): def test_dit_512(self): pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512") pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) - pipe.to("cuda") + pipe.to(torch_device) words = ["vase", "umbrella"] ids = pipe.get_label_ids(words) @@ -152,4 +160,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase): f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy" ) - assert np.abs((expected_image - image).max()) < 1e-1 + expected_slice = expected_image.flatten() + output_slice = image.flatten() + + assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2 diff --git a/tests/pipelines/easyanimate/test_easyanimate.py b/tests/pipelines/easyanimate/test_easyanimate.py index 13d5c2f49b..161734a166 100644 --- a/tests/pipelines/easyanimate/test_easyanimate.py +++ b/tests/pipelines/easyanimate/test_easyanimate.py @@ -27,9 +27,10 @@ from diffusers import ( FlowMatchEulerDiscreteScheduler, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -256,19 +257,19 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase): @slow -@require_torch_gpu +@require_torch_accelerator class EasyAnimatePipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." 
def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_EasyAnimate(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index ea2d015af5..55b0efbe60 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -27,8 +27,8 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, nightly, numpy_cosine_similarity_distance, - require_big_gpu_with_torch_cuda, - require_torch_gpu, + require_big_accelerator, + require_torch_accelerator, torch_device, ) @@ -266,9 +266,9 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte @nightly -@require_torch_gpu -@require_big_gpu_with_torch_cuda -@pytest.mark.big_gpu_with_torch_cuda +@require_torch_accelerator +@require_big_accelerator +@pytest.mark.big_accelerator class MochiPipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." @@ -302,5 +302,5 @@ class MochiPipelineIntegrationTests(unittest.TestCase): video = videos[0] expected_video = torch.randn(1, 19, 480, 848, 3).numpy() - max_diff = numpy_cosine_similarity_distance(video, expected_video) + max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video) assert max_diff < 1e-3, f"Max diff is too high. got {video}" diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 2f9c4d4e3f..e8f84eb913 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -7,8 +7,10 @@ from transformers import AutoTokenizer from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel from diffusers.utils.testing_utils import ( + Expectations, + backend_empty_cache, numpy_cosine_similarity_distance, - require_torch_gpu, + require_torch_accelerator, slow, torch_device, ) @@ -87,7 +89,7 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @slow -@require_torch_gpu +@require_torch_accelerator class OmniGenPipelineSlowTests(unittest.TestCase): pipeline_class = OmniGenPipeline repo_id = "shitao/OmniGen-v1-diffusers" @@ -95,12 +97,12 @@ class OmniGenPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, seed=0): if str(device).startswith("mps"): @@ -125,21 +127,56 @@ class OmniGenPipelineSlowTests(unittest.TestCase): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - [0.1783447, 0.16772744, 0.14339337], - [0.17066911, 0.15521264, 0.13757327], - [0.17072496, 0.15531206, 0.13524258], - [0.16746324, 0.1564025, 0.13794944], - [0.16490817, 0.15258026, 0.13697758], - [0.16971767, 0.15826806, 0.13928896], - [0.16782972, 0.15547255, 0.13783783], - [0.16464645, 0.15281534, 0.13522372], - [0.16535294, 0.15301755, 0.13526791], - [0.16365296, 0.15092957, 0.13443318], - ], - dtype=np.float32, + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + [0.05859375, 0.05859375, 0.04492188], + [0.04882812, 0.04101562, 0.03320312], + [0.04882812, 0.04296875, 0.03125], + 
[0.04296875, 0.0390625, 0.03320312], + [0.04296875, 0.03710938, 0.03125], + [0.04492188, 0.0390625, 0.03320312], + [0.04296875, 0.03710938, 0.03125], + [0.04101562, 0.03710938, 0.02734375], + [0.04101562, 0.03515625, 0.02734375], + [0.04101562, 0.03515625, 0.02929688], + ], + dtype=np.float32, + ), + ("cuda", 7): np.array( + [ + [0.1783447, 0.16772744, 0.14339337], + [0.17066911, 0.15521264, 0.13757327], + [0.17072496, 0.15531206, 0.13524258], + [0.16746324, 0.1564025, 0.13794944], + [0.16490817, 0.15258026, 0.13697758], + [0.16971767, 0.15826806, 0.13928896], + [0.16782972, 0.15547255, 0.13783783], + [0.16464645, 0.15281534, 0.13522372], + [0.16535294, 0.15301755, 0.13526791], + [0.16365296, 0.15092957, 0.13443318], + ], + dtype=np.float32, + ), + ("cuda", 8): np.array( + [ + [0.0546875, 0.05664062, 0.04296875], + [0.046875, 0.04101562, 0.03320312], + [0.05078125, 0.04296875, 0.03125], + [0.04296875, 0.04101562, 0.03320312], + [0.0390625, 0.03710938, 0.02929688], + [0.04296875, 0.03710938, 0.03125], + [0.0390625, 0.03710938, 0.02929688], + [0.0390625, 0.03710938, 0.02734375], + [0.0390625, 0.03320312, 0.02734375], + [0.0390625, 0.03320312, 0.02734375], + ], + dtype=np.float32, + ), + } ) + expected_slice = expected_slices.get_expectation() max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index 6b668de276..4192bc71ca 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -25,11 +25,12 @@ from transformers import CLIPImageProcessor, CLIPVisionConfig from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_image, nightly, - require_torch_gpu, + require_torch_accelerator, torch_device, ) @@ -174,19 +175,19 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class PaintByExamplePipelineIntegrationTests(unittest.TestCase): def setUp(self): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_paint_by_example(self): # make sure here that pndm scheduler skips prk diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py index 01df82056c..f8f4803ccf 100644 --- a/tests/pipelines/stable_audio/test_stable_audio.py +++ b/tests/pipelines/stable_audio/test_stable_audio.py @@ -32,7 +32,14 @@ from diffusers import ( StableAudioProjectionModel, ) from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import ( + Expectations, + backend_empty_cache, + enable_full_determinism, + nightly, + require_torch_accelerator, + torch_device, +) from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -419,17 +426,17 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, 
unittest.TestCase): @nightly -@require_torch_gpu +@require_torch_accelerator class StableAudioPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) @@ -459,9 +466,15 @@ class StableAudioPipelineIntegrationTests(unittest.TestCase): # check the portion of the generated audio with the largest dynamic range (reduces flakiness) audio_slice = audio[0, 447590:447600] # fmt: off - expected_slice = np.array( - [-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060] + expected_slices = Expectations( + { + ("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]), + ("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]), + ("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]), + } ) - # fmt: one + # fmt: on + + expected_slice = expected_slices.get_expectation() max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max() assert max_diff < 1.5e-3 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 096ee4c344..11ee1f0b9f 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -389,7 +389,7 @@ class BnB4BitBasicTests(Base4bitTests): class BnB4BitTrainingTests(Base4bitTests): def setUp(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) nf4_config = BitsAndBytesConfig( load_in_4bit=True, @@ -657,7 +657,7 @@ class SlowBnb4BitTests(Base4bitTests): class SlowBnb4BitFluxTests(Base4bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) model_id = "hf-internal-testing/flux.1-dev-nf4-pkg" t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") @@ -674,7 +674,7 @@ class SlowBnb4BitFluxTests(Base4bitTests): del self.pipeline_4bit gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_quality(self): # keep the resolution and max tokens to a lower number for faster execution. @@ -722,7 +722,7 @@ class SlowBnb4BitFluxTests(Base4bitTests): class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests): def setUp(self) -> None: gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16) self.pipeline_4bit.enable_model_cpu_offload() @@ -731,7 +731,7 @@ class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests): del self.pipeline_4bit gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_lora_loading(self): self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py new file mode 100644 index 0000000000..b82b2889d7 --- /dev/null +++ b/tests/quantization/test_pipeline_level_quantization.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tempfile
+import unittest
+
+import torch
+
+from diffusers import DiffusionPipeline, QuantoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils.testing_utils import (
+    is_transformers_available,
+    require_accelerate,
+    require_bitsandbytes_version_greater,
+    require_quanto,
+    require_torch,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
+
+
+if is_transformers_available():
+    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
+else:
+    TranBitsAndBytesConfig = None
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_quanto
+@require_accelerate
+@require_torch
+@require_torch_accelerator
+@slow
+class PipelineQuantizationTests(unittest.TestCase):
+    model_name = "hf-internal-testing/tiny-flux-pipe"
+    prompt = "a beautiful sunset amidst the mountains."
+    num_inference_steps = 10
+    seed = 0
+
+    def test_quant_config_set_correctly_through_kwargs(self):
+        components_to_quantize = ["transformer", "text_encoder_2"]
+        quant_config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_4bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=components_to_quantize,
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        for name, component in pipe.components.items():
+            if name in components_to_quantize:
+                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+                quantization_config = component.config.quantization_config
+                self.assertTrue(quantization_config.load_in_4bit)
+                self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+
+        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
+
+    def test_quant_config_set_correctly_through_granular(self):
+        quant_config = PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8"),
+                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+            }
+        )
+        components_to_quantize = list(quant_config.quant_mapping.keys())
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        for name, component in pipe.components.items():
+            if name in components_to_quantize:
+                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+                quantization_config = component.config.quantization_config
+
+                if name == "text_encoder_2":
+                    self.assertTrue(quantization_config.load_in_4bit)
+                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+                else:
+                    self.assertTrue(quantization_config.quant_method == "quanto")
+
+        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
+
+    def test_raises_error_for_invalid_config(self):
+        with self.assertRaises(ValueError) as err_context:
+            _ = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": QuantoConfig(weights_dtype="int8"),
+                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+                },
+                quant_backend="bitsandbytes_4bit",
+            )
+
+        self.assertTrue(
+            str(err_context.exception)
+            == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
+        )
+
+    def test_validation_for_kwargs(self):
+        components_to_quantize = ["transformer", "text_encoder_2"]
+        with self.assertRaises(ValueError) as err_context:
+            _ = PipelineQuantizationConfig(
+                quant_backend="quanto",
+                quant_kwargs={"weights_dtype": "int8"},
+                components_to_quantize=components_to_quantize,
+            )
+
+        self.assertTrue(
+            "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
+        )
+
+    def test_raises_error_for_wrong_config_class(self):
+        quant_config = {
+            "transformer": QuantoConfig(weights_dtype="int8"),
+            "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+        }
+        with self.assertRaises(ValueError) as err_context:
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name,
+                quantization_config=quant_config,
+                torch_dtype=torch.bfloat16,
+            )
+        self.assertTrue(
+            str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
+        )
+
+    def test_validation_for_mapping(self):
+        with self.assertRaises(ValueError) as err_context:
+            _ = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": DiffusionPipeline(),
+                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+                }
+            )
+
+        self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))
+
+    def test_saving_loading(self):
+        quant_config = PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8"),
+                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+            }
+        )
+        components_to_quantize = list(quant_config.quant_mapping.keys())
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+
+        pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
+        output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
+        for name, component in loaded_pipe.components.items():
+            if name in components_to_quantize:
+                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+                quantization_config = component.config.quantization_config
+
+                if name == "text_encoder_2":
+                    self.assertTrue(quantization_config.load_in_4bit)
+                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+                else:
+                    self.assertTrue(quantization_config.quant_method == "quanto")
+
+        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
+
+        self.assertTrue(torch.allclose(output_1, output_2))
diff --git a/utils/print_env.py b/utils/print_env.py
index 0a1cfbef13..2d2acb59d5 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -34,13 +34,24 @@ try:
     print("Torch version:", torch.__version__)
     print("Cuda available:", torch.cuda.is_available())
-    print("Cuda version:", torch.version.cuda)
-    print("CuDNN version:", torch.backends.cudnn.version())
version:", torch.backends.cudnn.version()) - print("Number of GPUs available:", torch.cuda.device_count()) if torch.cuda.is_available(): + print("Cuda version:", torch.version.cuda) + print("CuDNN version:", torch.backends.cudnn.version()) + print("Number of GPUs available:", torch.cuda.device_count()) device_properties = torch.cuda.get_device_properties(0) total_memory = device_properties.total_memory / (1024**3) print(f"CUDA memory: {total_memory} GB") + + print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available()) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + print("XPU model:", torch.xpu.get_device_properties(0).name) + print("XPU compiler version:", torch.version.xpu) + print("Number of XPUs available:", torch.xpu.device_count()) + device_properties = torch.xpu.get_device_properties(0) + total_memory = device_properties.total_memory / (1024**3) + print(f"XPU memory: {total_memory} GB") + + except ImportError: print("Torch version:", None)