Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Merge branch 'main' into group-offloading-with-disk
@@ -416,6 +416,45 @@ text_encoder_2_4bit.dequantize()
transformer_4bit.dequantize()
```

## torch.compile

Speed up inference with `torch.compile`. Make sure the latest `bitsandbytes` is installed; we also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).

<hfoptions id="bnb">
<hfoption id="8-bit">

```py
import torch
from diffusers import AutoModel, BitsAndBytesConfig as DiffusersBitsAndBytesConfig

torch._dynamo.config.capture_dynamic_output_shape_ops = True

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer_8bit.compile(fullgraph=True)
```

</hfoption>
<hfoption id="4-bit">

```py
import torch
from diffusers import AutoModel, BitsAndBytesConfig as DiffusersBitsAndBytesConfig

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
transformer_4bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer_4bit.compile(fullgraph=True)
```

</hfoption>
</hfoptions>

On an RTX 4090, compiled 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without compilation.

Check out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.
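As a rough, minimal sketch of such a measurement (not the linked script; the warmup call, prompt, and step count are assumptions), you could time a generation after compilation has been triggered once:

```py
import time
import torch
from diffusers import FluxPipeline

# Assumes `transformer_4bit` from the snippet above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer_4bit, torch_dtype=torch.float16
).to("cuda")

pipe("warmup", num_inference_steps=2)  # first call triggers compilation

torch.cuda.synchronize()
start = time.perf_counter()
pipe("a photo of a cat", num_inference_steps=28)
torch.cuda.synchronize()
print(f"generation took {time.perf_counter() - start:.3f}s")
```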

## Resources

* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
    from ..quantizers.gguf.utils import dequantize_gguf_tensor

    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"

    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
        raise ValueError(
            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
        )
    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
        raise ValueError(
            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
        )
    if is_gguf_quantized and not is_gguf_available():
        raise ValueError(
            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
        weight_on_cpu = True

    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    if is_bnb_4bit_quantized:
    if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
        module_weight = dequantize_bnb_weight(
            module.weight.to(device) if weight_on_cpu else module.weight,
            state=module.weight.quant_state,
            state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
            dtype=model.dtype,
        ).data
    elif is_gguf_quantized:
@@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed(

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    is_npu = freqs.device.type == "npu"
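The simplification above works because `torch.arange(0, dim, 2)` already produces `dim // 2` elements, making the `[: (dim // 2)]` slice a no-op. A runnable sketch with illustrative values:

```py
import torch

# Minimal sketch of the simplified 1D RoPE frequency computation
# (dim, theta, and linear_factor values here are illustrative).
dim, theta, linear_factor = 8, 10000.0, 1.0
pos = torch.arange(4, dtype=torch.float32)  # sequence positions [S]

# torch.arange(0, dim, 2) already has dim // 2 elements, so no slicing is needed.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) / linear_factor  # [D/2]
freqs = torch.outer(pos, freqs)  # [S, D/2]
print(freqs.shape)  # torch.Size([4, 4])
```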
@@ -816,14 +816,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device. Defaults to `None`, meaning that the model will be loaded on CPU.

                Examples:

                ```py
                >>> from diffusers import AutoModel
                >>> import torch

                >>> # This works.
                >>> model = AutoModel.from_pretrained(
                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
                ... )
                >>> # This also works (integer accelerator device ID).
                >>> model = AutoModel.from_pretrained(
                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
                ... )
                >>> # Specifying a supported offloading strategy like "auto" also works.
                >>> model = AutoModel.from_pretrained(
                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
                ... )
                >>> # Specifying a dictionary as `device_map` also works.
                >>> model = AutoModel.from_pretrained(
                ...     "stabilityai/stable-diffusion-xl-base-1.0",
                ...     subfolder="unet",
                ...     device_map={"": torch.device("cuda")},
                ... )
                ```

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
                map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
                can also refer to the [Diffusers-specific
                documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
                for more concrete examples.
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
@@ -1389,7 +1418,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        low_cpu_mem_usage: bool = True,
        dtype: Optional[Union[str, torch.dtype]] = None,
        keep_in_fp32_modules: Optional[List[str]] = None,
        device_map: Dict[str, Union[int, str, torch.device]] = None,
        device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
        offload_state_dict: Optional[bool] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = None,
        dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
@@ -898,6 +898,7 @@ class FluxPipeline(
        )

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
@@ -1193,6 +1193,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

            if padding_mask_crop is not None:
                image = [
                    self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
                ]

        # Offload all models
        self.maybe_free_model_hooks()
@@ -669,14 +669,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn’t need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            device_map (`str`, *optional*):
                Strategy that dictates how the different components of a pipeline should be placed on available
                devices. Currently, only "balanced" `device_map` is supported. Check out
                [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
                to know more.
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
@@ -388,8 +388,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
                The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (`guidance_scale` < `1`).
            height (`int`, defaults to `480`):
                The height in pixels of the generated image.
            width (`int`, defaults to `832`):

@@ -434,8 +436,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                The dtype to use for the torch.amp.autocast.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
                truncated. If the prompt is shorter, it will be padded to this length.

        Examples:
@@ -562,12 +562,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to `512`):
                The maximum sequence length of the prompt.
            shift (`float`, *optional*, defaults to `5.0`):
                The shift of the flow.
            autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                The dtype to use for the torch.amp.autocast.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
                truncated. If the prompt is shorter, it will be padded to this length.

        Examples:

        Returns:
@@ -687,8 +687,33 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            video (`List[PIL.Image.Image]`, *optional*):
                The input video or videos to be used as a starting point for the generation. The video should be a list
                of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one
                video at a time.
            mask (`List[PIL.Image.Image]`, *optional*):
                The input mask defines which video regions to condition on and which to generate. Black areas in the
                mask indicate conditioning regions, while white areas indicate regions for generation. The mask should
                be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video
                at a time.
            reference_images (`List[PIL.Image.Image]`, *optional*):
                A list of one or more reference images as extra conditioning for the generation. For example, if you
                are trying to inpaint a video to change the character, you can pass reference images of the new
                character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582)
                and original [user
                guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md)
                for a full list of supported tasks and use cases.
            conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`):
                The conditioning scale to be applied when adding the control conditioning latent stream to the
                denoising latent stream in each control layer of the model. If a float is provided, it will be applied
                uniformly to all layers. If a list or tensor is provided, it should have the same length as the number
                of control layers in the model (`len(transformer.config.vace_layers)`).
            height (`int`, defaults to `480`):
                The height in pixels of the generated image.
            width (`int`, defaults to `832`):

@@ -733,8 +758,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                The dtype to use for the torch.amp.autocast.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
                truncated. If the prompt is shorter, it will be padded to this length.

        Examples:
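Relatedly, the `conditioning_scale` argument documented above accepts one value per VACE control layer. A hypothetical usage sketch (assumes a loaded `pipe` of type `WanVACEPipeline` and prepared `video`/`mask` inputs as described in the docstring):

```py
# One scale per control layer, e.g. down-weighting the later half.
num_control_layers = len(pipe.transformer.config.vace_layers)
half = num_control_layers // 2
scales = [1.0] * half + [0.5] * (num_control_layers - half)

frames = pipe(
    prompt="a cat walking on grass",
    video=video,
    mask=mask,
    conditioning_scale=scales,
).frames[0]
```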
@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        )

        if latents is None:
            if isinstance(generator, list):
                init_latents = [
                    retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
                ]
            else:
                init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
            init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]

            init_latents = torch.cat(init_latents, dim=0).to(dtype)
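For context, `retrieve_latents` in diffusers pipelines dispatches on `sample_mode`; a simplified sketch of that helper (not necessarily the exact source) shows why `"argmax"` makes the encoding deterministic and removes the need for per-item generators:

```py
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # "sample" draws from the VAE posterior (stochastic, needs a generator);
    # "argmax" takes the distribution mode (deterministic).
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
```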
@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            if hasattr(self.scheduler, "add_noise"):
                latents = self.scheduler.add_noise(init_latents, noise, timestep)
            else:
                latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
                latents = self.scheduler.scale_noise(init_latents, timestep, noise)
        else:
            latents = latents.to(device)
@@ -513,7 +508,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, defaults to `480`):
                The height in pixels of the generated image.

@@ -530,6 +525,8 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            strength (`float`, defaults to `0.8`):
                Higher strength leads to more differences between original image and generated video.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -559,8 +556,9 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
                The dtype to use for the torch.amp.autocast.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
                truncated. If the prompt is shorter, it will be padded to this length.

        Examples:
@@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin):
        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
        if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
            is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                raise ValueError(
                    f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                    f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."

@@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin):
        QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
        QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)

        if cls._is_cuda_capability_atleast_8_9():
        if cls._is_xpu_or_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)

        return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin):
        )

    @staticmethod
    def _is_cuda_capability_atleast_8_9() -> bool:
        if not torch.cuda.is_available():
            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")

        major, minor = torch.cuda.get_device_capability()
        if major == 8:
            return minor >= 9
        return major >= 9
    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability()
            if major == 8:
                return minor >= 9
            return major >= 9
        elif torch.xpu.is_available():
            return True
        else:
            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

    def get_apply_tensor_subclass(self):
        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
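The same gate can be expressed as a standalone helper; a minimal sketch mirroring the logic above (the 8.9 threshold corresponds to fp8-capable NVIDIA GPUs such as Ada/Hopper, and XPU is assumed supported, as in the diff):

```py
import torch

def supports_float8() -> bool:
    # CUDA: require compute capability >= 8.9; XPU: assume supported.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 9)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    return False

print("float8 quantization available:", supports_float8())
```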
@@ -154,12 +154,30 @@ def check_imports(filename):
    return get_relative_imports(filename)


def get_class_in_module(class_name, module_path):
def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
    """
    Import a module on the cache directory for modules and extract a class from it.
    """
    module_path = module_path.replace(os.path.sep, ".")
    module = importlib.import_module(module_path)
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
        # separator. We do a bit of monkey patching to detect and fix this case.
        if not (
            pretrained_model_name_or_path is not None
            and "." in pretrained_model_name_or_path
            and module_path.startswith("diffusers_modules")
            and pretrained_model_name_or_path.replace("/", "--") in module_path
        ):
            raise e  # We can't figure this one out, just reraise the original error

        corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
        corrected_path = corrected_path.replace(
            pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
            pretrained_model_name_or_path.replace("/", "--"),
        )
        module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()

    if class_name is None:
        return find_pipeline_class(module)
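For context on the fallback above: `importlib.import_module` treats every `.` in a module path as a package separator, so a cached repo id containing `.` resolves to a non-existent file path. A self-contained sketch of the failure and the loader-based workaround (directory names here are illustrative, not the real cache layout):

```py
import importlib
import importlib.machinery
import os
import tempfile

# Simulate a cached custom-pipeline module whose repo id contains ".".
root = tempfile.mkdtemp()
repo_dir = os.path.join(root, "diffusers_modules", "local", "repo-with.dot")
os.makedirs(repo_dir)
with open(os.path.join(repo_dir, "pipeline.py"), "w") as f:
    f.write("ANSWER = 42\n")

module_path = "diffusers_modules.local.repo-with.dot.pipeline"

try:
    # Fails: the "." inside the repo id is treated as a package separator,
    # so Python looks for .../repo-with/dot/pipeline.py, which doesn't exist.
    module = importlib.import_module(module_path)
except ModuleNotFoundError:
    # Workaround: load the file from its real on-disk location, keeping the
    # dotted string only as the module's registered name.
    corrected_path = os.path.join(repo_dir, "pipeline.py")
    module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()

print(module.ANSWER)  # 42
```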
@@ -454,4 +472,4 @@ def get_class_from_dynamic_module(
        revision=revision,
        local_files_only=local_files_only,
    )
    return get_class_in_module(class_name, final_module.replace(".py", ""))
    return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
@@ -99,6 +99,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
else:
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False
    _torch_version = "N/A"

_jax_version = "N/A"
_flax_version = "N/A"
@@ -291,6 +291,18 @@ def require_torch_version_greater_equal(torch_version):
    return decorator


def require_torch_version_greater(torch_version):
    """Decorator marking a test that requires torch with a specific version greater."""

    def decorator(test_case):
        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
        return unittest.skipUnless(
            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
        )(test_case)

    return decorator


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
@@ -300,9 +312,7 @@ def require_torch_gpu(test_case):

def require_torch_cuda_compatibility(expected_compute_capability):
    def decorator(test_case):
        if not torch.cuda.is_available():
            return unittest.skip(test_case)
        else:
        if torch.cuda.is_available():
            current_compute_capability = get_torch_cuda_device_capability()
            return unittest.skipUnless(
                float(current_compute_capability) == float(expected_compute_capability),
@@ -21,6 +21,7 @@ import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    load_image,
    slow,

@@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    def test_encode_decode(self):
@@ -22,6 +22,7 @@ import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    require_torch_accelerator,

@@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

        # two models don't need to stay in the device at the same time
        del model_accelerate
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)
        gc.collect()

        model_normal_load, _ = UNet2DModel.from_pretrained(
@@ -978,13 +978,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
        assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

    @require_torch_gpu
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)

@@ -994,13 +994,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_gpu
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
@@ -1084,6 +1084,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @parameterized.expand(
        [
            (-1, "You can't pass device_map as a negative int"),
            ("foo", "When passing device_map as a string, the value needs to be a device name"),
        ]
    )
    def test_wrong_device_map_raises_error(self, device_map, msg_substring):
        with self.assertRaises(ValueError) as err_ctx:
            _ = self.model_class.from_pretrained(
                "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
            )

        assert msg_substring in str(err_ctx.exception)

    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
    @require_torch_gpu
    def test_passing_non_dict_device_map_works(self, device_map):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(
            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
        )
        output = loaded_model(**inputs_dict)
        assert output.sample.shape == (4, 4, 16, 16)

    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
    @require_torch_gpu
    def test_passing_dict_device_map_works(self, name, device_map):
        # There are other valid dict-based `device_map` values too. It's best to refer to
        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(
            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
        )
        output = loaded_model(**inputs_dict)
        assert output.sample.shape == (4, 4, 16, 16)

    @require_peft_backend
    def test_load_attn_procs_raise_warning(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel

from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_hf_hub_version_greater,

@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_allegro(self):
        generator = torch.Generator("cpu").manual_seed(0)
@@ -37,7 +37,7 @@ from diffusers import (
    UNet2DConditionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin

@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)

@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -45,7 +45,13 @@ from diffusers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    is_torch_version,
    nightly,
    torch_device,
)

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin

@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,

@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_cogvideox(self):
        generator = torch.Generator("cpu").manual_seed(0)

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,

@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_cogview3plus(self):
        generator = torch.Generator("cpu").manual_seed(0)
@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_numpy,

@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_canny(self):
        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_numpy,

@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_canny(self):
        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
@@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3ControlNetPipeline

@@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    load_numpy,

@@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase):

        image = output.images[0]

        mem_bytes = torch.cuda.max_memory_allocated()
        mem_bytes = backend_max_memory_allocated(torch_device)
        assert mem_bytes < 12 * 10**9

        expected_image = load_numpy(

@@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    floats_tensor,

@@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
        )
        image = output.images[0]

        mem_bytes = torch.cuda.max_memory_allocated()
        mem_bytes = backend_max_memory_allocated(torch_device)
        assert mem_bytes < 12 * 10**9

        expected_image = load_numpy(
@@ -224,7 +224,7 @@ class FluxPipelineFastTests(

@nightly
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-schnell"

@@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase):

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-dev"

@@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import (

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
    pipeline_class = FluxPriorReduxPipeline
    repo_id = "black-forest-labs/FLUX.1-Redux-dev"
@@ -23,6 +23,7 @@ from transformers import AutoTokenizer, BertModel, T5EncoderModel

from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,

@@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_hunyuan_dit_1024(self):
        generator = torch.Generator("cpu").manual_seed(0)

@@ -27,6 +27,7 @@ from diffusers import (
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    load_numpy,
    nightly,
    numpy_cosine_similarity_distance,

@@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_attend_and_excite_fp16(self):
        generator = torch.manual_seed(51)
@@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    enable_full_determinism,

@@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        mem_bytes = backend_max_memory_allocated(torch_device)
        # make sure that less than 2.65 GB is allocated
        assert mem_bytes < 2.65 * 10**9

@@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

@@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte

@slow
@require_big_accelerator
@pytest.mark.big_gpu_with_torch_cuda
@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Img2ImgPipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
@@ -35,6 +35,7 @@ from diffusers import (
    UniPCMultistepScheduler,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    load_image,
    numpy_cosine_similarity_distance,

@@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_stable_diffusion_lcm(self):
        torch.manual_seed(0)

@@ -39,6 +39,7 @@ from diffusers import (
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_image,

@@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def test_stable_diffusion_xl_img2img_playground(self):
        torch.manual_seed(0)
@@ -1105,6 +1105,21 @@ class CustomPipelineTests(unittest.TestCase):

        assert images.shape == (1, 64, 64, 3)

    def test_remote_custom_pipe_with_dot_in_name(self):
        # make sure that trust_remote_code has to be passed
        with self.assertRaises(ValueError):
            pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name")

        pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True)

        assert pipeline.__class__.__name__ == "CustomPipeline"

        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert images[0].shape == (1, 32, 32, 3)
        assert output_str == "This is a test"

    def test_local_custom_pipeline_repo(self):
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        pipeline = DiffusionPipeline.from_pretrained(

@@ -1203,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def dummy_image(self):
        batch_size = 1
@@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    slow,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS

@@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    @unittest.skip("TODO: test needs to be implemented")
    def test_Wanx(self):
@@ -30,6 +30,7 @@ from diffusers import (
    FluxTransformer2DModel,
    SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,

@@ -44,11 +45,14 @@ from diffusers.utils.testing_utils import (
    require_peft_backend,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:

@@ -855,3 +859,26 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):

    def test_fp4_double_safe(self):
        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
@@ -19,15 +19,18 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from PIL import Image

from diffusers import (
    BitsAndBytesConfig,
    DiffusionPipeline,
    FluxControlPipeline,
    FluxTransformer2DModel,
    SanaTransformer2DModel,
    SD3Transformer2DModel,
    logging,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
    CaptureLogger,

@@ -39,14 +42,18 @@ from diffusers.utils.testing_utils import (
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_bitsandbytes_version_greater,
    require_peft_backend,
    require_peft_version_greater,
    require_torch,
    require_torch_accelerator,
    require_torch_version_greater_equal,
    require_transformers_version_greater,
    slow,
    torch_device,
)

from ..test_torch_compile_utils import QuantCompileTests


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -697,6 +704,50 @@ class SlowBnb8bitFluxTests(Base8bitTests):
        self.assertTrue(max_diff < 1e-3)


@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
    def setUp(self) -> None:
        gc.collect()
        backend_empty_cache(torch_device)

        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            quantization_config=PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=["transformer", "text_encoder_2"],
            ),
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
        backend_empty_cache(torch_device)

    def test_lora_loading(self):
        self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

        output = self.pipeline_8bit(
            prompt=self.prompt,
            control_image=Image.new(mode="RGB", size=(256, 256)),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.Generator().manual_seed(42),
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")


@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
    def setUp(self):
@@ -773,3 +824,27 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))


@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload(self):
        super()._test_torch_compile_with_group_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )
tests/quantization/test_torch_compile_utils.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
    quantization_config = None

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _init_pipeline(self, quantization_config, torch_dtype):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        )
        return pipe

    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # important: compile with fullgraph=True to surface any graph breaks
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": True,
            "non_blocking": True,
        }
        pipe.transformer.enable_group_offload(**group_offload_kwargs)
        pipe.transformer.compile()
        for name, component in pipe.components.items():
            if name != "transformer" and isinstance(component, torch.nn.Module):
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
@@ -30,13 +30,15 @@ from diffusers import (
)
from diffusers.models.attention_processor import Attention
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_synchronize,
    enable_full_determinism,
    is_torch_available,
    is_torchao_available,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch,
    require_torch_gpu,
    require_torch_accelerator,
    require_torchao_version_greater_or_equal,
    slow,
    torch_device,
@@ -61,7 +63,7 @@ if is_torchao_available():


@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
    def test_to_dict(self):
@@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase):
        Test kwargs validations in TorchAoConfig
        """
        _ = TorchAoConfig("int4_weight_only")
        with self.assertRaisesRegex(ValueError, "is not supported yet"):
        with self.assertRaisesRegex(ValueError, "is not supported"):
            _ = TorchAoConfig("uint8")

        with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +121,12 @@ class TorchAoConfigTest(unittest.TestCase):

# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_dummy_components(
        self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
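The recurring substitution in these hunks swaps CUDA-only calls (`require_torch_gpu`, `torch.cuda.empty_cache()`, `torch.cuda.synchronize()`) for backend-agnostic helpers so the suite can run on any supported accelerator. A rough sketch of how such dispatch helpers can be written; this illustrates the pattern, not the actual `diffusers.utils.testing_utils` implementation:

```py
import torch


def backend_empty_cache(device: str) -> None:
    # Route the cache flush to whichever accelerator backend is active.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
    # CPU has no device cache to flush, so fall through silently.


def backend_synchronize(device: str) -> None:
    # Block until all queued kernels on the active backend have finished.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
    elif device == "mps":
        torch.mps.synchronize()
```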
@@ -269,6 +271,7 @@ class TorchAoTest(unittest.TestCase):
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map=f"{torch_device}:0",
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +341,7 @@ class TorchAoTest(unittest.TestCase):

        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

        with tempfile.TemporaryDirectory() as offload_folder:
            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@ class TorchAoTest(unittest.TestCase):

        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

    def test_modules_to_not_convert(self):
        quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
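The tolerance relaxed in both hunks above (1e-3 to 2e-3) bounds a cosine-distance check between output slices. A minimal sketch of that comparison, assuming the helper computes one minus cosine similarity, as its name suggests:

```py
import numpy as np


def numpy_cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 means the flattened outputs point in exactly the same direction.
    cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.abs(1.0 - cos_sim))


# The relaxed assertion tolerates slightly more backend-dependent numerical drift:
# numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3
```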
@@ -518,14 +521,14 @@ class TorchAoTest(unittest.TestCase):

# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@ class TorchAoSerializationTest(unittest.TestCase):
        )
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

    def test_int_a8w8_cuda(self):
    def test_int_a8w8_accelerator(self):
        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
        device = "cuda"
        device = torch_device
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a16w8_cuda(self):
    def test_int_a16w8_accelerator(self):
        quant_method, quant_method_kwargs = "int8_weight_only", {}
        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
        device = "cuda"
        device = torch_device
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

@@ -624,14 +627,14 @@ class TorchAoSerializationTest(unittest.TestCase):

# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_dummy_components(self, quantization_config: TorchAoConfig):
        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
@@ -713,8 +716,8 @@ class SlowTorchAoTests(unittest.TestCase):
        quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
        self._test_quant_type(quantization_config, expected_slice)
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        backend_empty_cache(torch_device)
        backend_synchronize(torch_device)

    def test_serialization_int8wo(self):
        quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ class SlowTorchAoTests(unittest.TestCase):
        pipe.remove_all_hooks()
        del pipe.transformer
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        backend_empty_cache(torch_device)
        backend_synchronize(torch_device)
        transformer = FluxTransformer2DModel.from_pretrained(
            tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
        )
@@ -783,14 +786,14 @@ class SlowTorchAoTests(unittest.TestCase):


@require_torch
@require_torch_gpu
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        backend_empty_cache(torch_device)

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
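The hunk above cuts off at the `mps` guard in `get_dummy_inputs`. In diffusers tests this guard usually exists because `torch.Generator` objects cannot target the `mps` device; a sketch of the typical continuation, assumed from the common pattern rather than shown in this hunk:

```py
import torch


def make_generator(device: str, seed: int = 0) -> torch.Generator:
    # torch.Generator cannot be created on "mps", so seed the global (CPU) RNG there instead.
    if str(device).startswith("mps"):
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)
```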
@@ -16,8 +16,6 @@
import gc
import unittest

import torch

from diffusers import (
    Lumina2Transformer2DModel,
)
@@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)
@@ -16,8 +16,6 @@
import gc
import unittest

import torch

from diffusers import (
    FluxTransformer2DModel,
)
@@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)
@@ -1,8 +1,6 @@
import gc
import unittest

import torch

from diffusers import (
    SanaTransformer2DModel,
)
@@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            torch.cuda.empty_cache()
            backend_empty_cache(torch_device)