
Merge branch 'main' into group-offloading-with-disk

This commit is contained in:
Sayak Paul
2025-06-12 11:08:56 +05:30
committed by GitHub
46 changed files with 521 additions and 132 deletions

View File

@@ -416,6 +416,45 @@ text_encoder_2_4bit.dequantize()
transformer_4bit.dequantize()
```
## torch.compile
Speed up inference with `torch.compile`. Make sure the latest `bitsandbytes` is installed; we also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).
<hfoptions id="bnb">
<hfoption id="8-bit">
```py
import torch

from diffusers import AutoModel, BitsAndBytesConfig as DiffusersBitsAndBytesConfig

torch._dynamo.config.capture_dynamic_output_shape_ops = True

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer_8bit.compile(fullgraph=True)
```
</hfoption>
<hfoption id="4-bit">
```py
import torch

from diffusers import AutoModel, BitsAndBytesConfig as DiffusersBitsAndBytesConfig

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
transformer_4bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer_4bit.compile(fullgraph=True)
```
</hfoption>
</hfoptions>
On an RTX 4090 with compilation, 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without.
Check out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.
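As a rough way to reproduce this kind of comparison, the sketch below times one steady-state generation after a compilation warmup. It assumes the 4-bit `transformer_4bit` from the snippet above; the prompt, step count, and CPU offloading are illustrative choices rather than the exact setup of the gist.
```py
import time

import torch
from diffusers import FluxPipeline

# Reuse the quantized, compiled transformer from the 4-bit example above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer_4bit, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# The first calls trigger compilation; time a run after warmup.
for _ in range(2):
    pipe("a photo of a dog", num_inference_steps=28)

torch.cuda.synchronize()
start = time.perf_counter()
pipe("a photo of a dog", num_inference_steps=28)
torch.cuda.synchronize()
print(f"{time.perf_counter() - start:.3f} s")
```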
## Resources
* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)

View File

@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
from ..quantizers.gguf.utils import dequantize_gguf_tensor
is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
if is_bnb_4bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_bnb_8bit_quantized and not is_bitsandbytes_available():
raise ValueError(
"The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
)
if is_gguf_quantized and not is_gguf_available():
raise ValueError(
"The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
weight_on_cpu = True
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
if is_bnb_4bit_quantized:
if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.to(device) if weight_on_cpu else module.weight,
state=module.weight.quant_state,
state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
dtype=model.dtype,
).data
elif is_gguf_quantized:

View File

@@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed(
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
is_npu = freqs.device.type == "npu"
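The simplification above works because `torch.arange(0, dim, 2)` already yields exactly `dim // 2` elements, so the removed `[: (dim // 2)]` slice was a no-op. A quick check:
```py
import torch

dim = 8
idx = torch.arange(0, dim, 2)  # tensor([0, 2, 4, 6])
assert idx.numel() == dim // 2  # the slice dropped in this hunk changed nothing
```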

View File

@@ -816,14 +816,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
Examples:
```py
>>> from diffusers import AutoModel
>>> import torch
>>> # A device string places the entire model on that device.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
... )
>>> # This also works (integer accelerator device ID).
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
... )
>>> # Specifying a supported offloading strategy like "auto" also works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
... )
>>> # Specifying a dictionary as `device_map` also works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0",
... subfolder="unet",
... device_map={"": torch.device("cuda")},
... )
```
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
can also refer to the [Diffusers-specific
documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
for more concrete examples.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory available
for each GPU and the available CPU RAM if unset.
@@ -1389,7 +1418,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage: bool = True,
dtype: Optional[Union[str, torch.dtype]] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
device_map: Dict[str, Union[int, str, torch.device]] = None,
device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,

View File

@@ -898,6 +898,7 @@ class FluxPipeline(
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:

View File

@@ -1193,6 +1193,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if padding_mask_crop is not None:
image = [
self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
]
# Offload all models
self.maybe_free_model_hooks()
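For context, a hedged usage sketch of the branch this hunk completes: when `padding_mask_crop` is passed, generation runs on a crop around the mask and `apply_overlay` pastes the result back onto the original image. The prompt and image paths are placeholders.
```py
import torch
from diffusers import FluxInpaintPipeline
from diffusers.utils import load_image

pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

init_image = load_image("input.png")  # placeholder path
mask_image = load_image("mask.png")   # placeholder path

# padding_mask_crop activates the overlay step added above.
result = pipe(
    prompt="a small wooden cabin in the forest",
    image=init_image,
    mask_image=mask_image,
    padding_mask_crop=32,
).images[0]
```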

View File

@@ -669,14 +669,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device.
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
device_map (`str`, *optional*):
Strategy that dictates how the different components of a pipeline should be placed on available
devices. Currently, only "balanced" `device_map` is supported. Check out
[this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
to know more.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory available
for each GPU and the available CPU RAM if unset.
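A minimal sketch of the pipeline-level behavior described in the updated `device_map` docstring, assuming a multi-GPU machine (the checkpoint is just an example):
```py
import torch
from diffusers import DiffusionPipeline

# Pipelines only accept the "balanced" strategy; device strings, integer ids,
# and per-module dicts are handled at the model level (see ModelMixin above).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
)
```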

View File

@@ -388,8 +388,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
@@ -434,8 +436,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:

View File

@@ -562,12 +562,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to `512`):
The maximum sequence length of the prompt.
shift (`float`, *optional*, defaults to `5.0`):
The shift of the flow.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:
Returns:

View File

@@ -687,8 +687,33 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
video (`List[PIL.Image.Image]`, *optional*):
The input video or videos to be used as a starting point for the generation. The video should be a list
of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one
video at a time.
mask (`List[PIL.Image.Image]`, *optional*):
The input mask defines which video regions to condition on and which to generate. Black areas in the
mask indicate conditioning regions, while white areas indicate regions for generation. The mask should
be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video
at a time.
reference_images (`List[PIL.Image.Image]`, *optional*):
A list of one or more reference images as extra conditioning for the generation. For example, if you
are trying to inpaint a video to change the character, you can pass reference images of the new
character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582)
and original [user
guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md)
for a full list of supported tasks and use cases.
conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`):
The conditioning scale to be applied when adding the control conditioning latent stream to the
denoising latent stream in each control layer of the model. If a float is provided, it will be applied
uniformly to all layers. If a list or tensor is provided, it should have the same length as the number
of control layers in the model (`len(transformer.config.vace_layers)`).
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
@@ -733,8 +758,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:
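To make the new parameter docs concrete, here is a hedged text-to-video sketch; the model id, resolution, and step count are assumptions, and `video`, `mask`, and `reference_images` are omitted since they are optional.
```py
import torch
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Model id is an assumption for illustration.
pipe = WanVACEPipeline.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

frames = pipe(
    prompt="a cat surfing a wave at sunset",
    negative_prompt="blurry, low quality",
    height=480,
    width=832,
    conditioning_scale=1.0,
    num_inference_steps=30,
).frames[0]
export_to_video(frames, "output.mp4", fps=16)
```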

View File

@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)
if latents is None:
if isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if hasattr(self.scheduler, "add_noise"):
latents = self.scheduler.add_noise(init_latents, noise, timestep)
else:
latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
latents = self.scheduler.scale_noise(init_latents, timestep, noise)
else:
latents = latents.to(device)
@@ -513,7 +508,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, defaults to `480`):
The height in pixels of the generated image.
@@ -530,6 +525,8 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
strength (`float`, defaults to `0.8`):
Higher strength leads to more differences between the original video and the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -559,8 +556,9 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:

View File

@@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin):
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
if cls._is_cuda_capability_atleast_8_9():
if cls._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin):
)
@staticmethod
def _is_cuda_capability_atleast_8_9() -> bool:
if not torch.cuda.is_available():
raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
else:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
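With this change, float8 quantization types pass the capability check on Intel XPUs as well as on CUDA GPUs with compute capability 8.9 or higher. A hedged sketch of the user-facing effect (the quant type shorthand and checkpoint are illustrative):
```py
import torch
from diffusers import AutoModel, TorchAoConfig

# Requires an Intel XPU, or a CUDA GPU with compute capability >= 8.9 (Ada/Hopper and newer).
quant_config = TorchAoConfig("float8wo_e4m3")
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```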

View File

@@ -154,12 +154,30 @@ def check_imports(filename):
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError as e:
# This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
# separator. We do a bit of monkey patching to detect and fix this case.
if not (
pretrained_model_name_or_path is not None
and "." in pretrained_model_name_or_path
and module_path.startswith("diffusers_modules")
and pretrained_model_name_or_path.replace("/", "--") in module_path
):
raise e # We can't figure this one out, just reraise the original error
corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
corrected_path = corrected_path.replace(
pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
pretrained_model_name_or_path.replace("/", "--"),
)
module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
if class_name is None:
return find_pipeline_class(module)
@@ -454,4 +472,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
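The user-facing effect of this fix: custom pipeline code hosted under a repo id containing a `.` can now be imported. The repo id below is the one exercised by the new test later in this diff.
```py
from diffusers import DiffusionPipeline

# Previously this failed because "." in the repo id was treated as a package separator.
pipe = DiffusionPipeline.from_pretrained(
    "akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True
)
```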

View File

@@ -99,6 +99,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
_torch_version = "N/A"
_jax_version = "N/A"
_flax_version = "N/A"

View File

@@ -291,6 +291,18 @@ def require_torch_version_greater_equal(torch_version):
return decorator
def require_torch_version_greater(torch_version):
"""Decorator marking a test that requires torch with a specific version greater."""
def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
)(test_case)
return decorator
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
@@ -300,9 +312,7 @@ def require_torch_gpu(test_case):
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if not torch.cuda.is_available():
return unittest.skip(test_case)
else:
if torch.cuda.is_available():
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),

View File

@@ -21,6 +21,7 @@ import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
slow,
@@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_encode_decode(self):

View File

@@ -22,6 +22,7 @@ import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
@@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# two models don't need to stay in the device at the same time
del model_accelerate
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained(

View File

@@ -978,13 +978,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -994,13 +994,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
@@ -1084,6 +1084,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
(-1, "You can't pass device_map as a negative int"),
("foo", "When passing device_map as a string, the value needs to be a device name"),
]
)
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
assert msg_substring in str(err_ctx.exception)
@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
@require_torch_gpu
def test_passing_dict_device_map_works(self, name, device_map):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_hf_hub_version_greater,
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_allegro(self):
generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -37,7 +37,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -45,7 +45,13 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
is_torch_version,
nightly,
torch_device,
)
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_cogview3plus(self):
generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

View File

@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

View File

@@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline

View File

@@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
load_numpy,
@@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase):
image = output.images[0]
mem_bytes = torch.cuda.max_memory_allocated()
mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(

View File

@@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
floats_tensor,
@@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
)
image = output.images[0]
mem_bytes = torch.cuda.max_memory_allocated()
mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(

View File

@@ -224,7 +224,7 @@ class FluxPipelineFastTests(
@nightly
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase):
@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"

View File

@@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import (
@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"

View File

@@ -23,6 +23,7 @@ from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
@@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_hunyuan_dit_1024(self):
generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -27,6 +27,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
@@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_attend_and_excite_fp16(self):
generator = torch.manual_seed(51)

View File

@@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
@@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
output_type="np",
)
mem_bytes = torch.cuda.max_memory_allocated()
mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.65 GB is allocated
assert mem_bytes < 2.65 * 10**9

View File

@@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -35,6 +35,7 @@ from diffusers import (
UniPCMultistepScheduler,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
numpy_cosine_similarity_distance,
@@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_stable_diffusion_lcm(self):
torch.manual_seed(0)

View File

@@ -39,6 +39,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
@@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_stable_diffusion_xl_img2img_playground(self):
torch.manual_seed(0)

View File

@@ -1105,6 +1105,21 @@ class CustomPipelineTests(unittest.TestCase):
assert images.shape == (1, 64, 64, 3)
def test_remote_custom_pipe_with_dot_in_name(self):
# make sure trust_remote_code is required (loading without it raises)
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name")
pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True)
assert pipeline.__class__.__name__ == "CustomPipeline"
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np")
assert images[0].shape == (1, 32, 32, 3)
assert output_str == "This is a test"
def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(
@@ -1203,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def dummy_image(self):
batch_size = 1

View File

@@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_Wanx(self):

View File

@@ -30,6 +30,7 @@ from diffusers import (
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -44,11 +45,14 @@ from diffusers.utils.testing_utils import (
require_peft_backend,
require_torch,
require_torch_accelerator,
require_torch_version_greater,
require_transformers_version_greater,
slow,
torch_device,
)
from ..test_torch_compile_utils import QuantCompileTests
def get_some_linear_layer(model):
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
def test_fp4_double_safe(self):
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)

View File

@@ -19,15 +19,18 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from PIL import Image
from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
FluxControlPipeline,
FluxTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,
logging,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -39,14 +42,18 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
require_peft_backend,
require_peft_version_greater,
require_torch,
require_torch_accelerator,
require_torch_version_greater_equal,
require_transformers_version_greater,
slow,
torch_device,
)
from ..test_torch_compile_utils import QuantCompileTests
def get_some_linear_layer(model):
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -697,6 +704,50 @@ class SlowBnb8bitFluxTests(Base8bitTests):
self.assertTrue(max_diff < 1e-3)
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb8bitFluxControlWithLoraTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
backend_empty_cache(torch_device)
self.pipeline_8bit = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
),
torch_dtype=torch.float16,
)
self.pipeline_8bit.enable_model_cpu_offload()
def tearDown(self):
del self.pipeline_8bit
gc.collect()
backend_empty_cache(torch_device)
def test_lora_loading(self):
self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
output = self.pipeline_8bit(
prompt=self.prompt,
control_image=Image.new(mode="RGB", size=(256, 256)),
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
@@ -773,3 +824,27 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
out_0 = self.model_0(**inputs)[0]
out_1 = model_1(**inputs)[0]
self.assertTrue(torch.equal(out_0, out_1))
@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
@pytest.mark.xfail(reason="Test fails because of an offloading issue in Accelerate related to hook handling.")
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)

View File

@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
quantization_config = None
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.compiler.reset()
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.compiler.reset()
def _init_pipeline(self, quantization_config, torch_dtype):
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
)
return pipe
def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
# important: compile with fullgraph=True to ensure there are no graph breaks
pipe.transformer.compile(fullgraph=True)
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
torch._dynamo.config.cache_size_limit = 10000
pipe = self._init_pipeline(quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()
for name, component in pipe.components.items():
if name != "transformer" and isinstance(component, torch.nn.Module):
if torch.device(component.device).type == "cpu":
component.to("cuda")
for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

View File

@@ -30,13 +30,15 @@ from diffusers import (
)
from diffusers.models.attention_processor import Attention
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_synchronize,
enable_full_determinism,
is_torch_available,
is_torchao_available,
nightly,
numpy_cosine_similarity_distance,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_torchao_version_greater_or_equal,
slow,
torch_device,
@@ -61,7 +63,7 @@ if is_torchao_available():
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
@@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase):
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported yet"):
with self.assertRaisesRegex(ValueError, "is not supported"):
_ = TorchAoConfig("uint8")
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +121,12 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_dummy_components(
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -269,6 +271,7 @@ class TorchAoTest(unittest.TestCase):
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map=f"{torch_device}:0",
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +341,7 @@ class TorchAoTest(unittest.TestCase):
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@ class TorchAoTest(unittest.TestCase):
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -518,14 +521,14 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@ class TorchAoSerializationTest(unittest.TestCase):
)
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def test_int_a8w8_cuda(self):
def test_int_a8w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cuda"
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
def test_int_a16w8_cuda(self):
def test_int_a16w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cuda"
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@@ -624,14 +627,14 @@ class TorchAoSerializationTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_dummy_components(self, quantization_config: TorchAoConfig):
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ class SlowTorchAoTests(unittest.TestCase):
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
backend_empty_cache(torch_device)
backend_synchronize(torch_device)
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ class SlowTorchAoTests(unittest.TestCase):
pipe.remove_all_hooks()
del pipe.transformer
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
backend_empty_cache(torch_device)
backend_synchronize(torch_device)
transformer = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
)
@@ -783,14 +786,14 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):

View File

@@ -16,8 +16,6 @@
import gc
import unittest
import torch
from diffusers import (
Lumina2Transformer2DModel,
)
@@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

View File

@@ -16,8 +16,6 @@
import gc
import unittest
import torch
from diffusers import (
FluxTransformer2DModel,
)
@@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

View File

@@ -1,8 +1,6 @@
import gc
import unittest
import torch
from diffusers import (
SanaTransformer2DModel,
)
@@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)