diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
index b1c130b792..dc095054e1 100644
--- a/docs/source/en/quantization/bitsandbytes.md
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -416,6 +416,45 @@ text_encoder_2_4bit.dequantize()
 transformer_4bit.dequantize()
 ```
 
+## torch.compile
+
+Speed up inference with `torch.compile`. Make sure you have the latest `bitsandbytes` installed, and we also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).
+
+<hfoptions id="bnb">
+<hfoption id="8-bit">
+```py
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = AutoModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+transformer_8bit.compile(fullgraph=True)
+```
+
+</hfoption>
+<hfoption id="4-bit">
+
+```py
+quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
+transformer_4bit = AutoModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+transformer_4bit.compile(fullgraph=True)
+```
+
+</hfoption>
+</hfoptions>
+
+On an RTX 4090 with compilation, 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without.
+
+Check out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.
+
 ## Resources
 
 * [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 6092eeff0a..189a9ceba5 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         from ..quantizers.gguf.utils import dequantize_gguf_tensor
 
     is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
     is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
 
     if is_bnb_4bit_quantized and not is_bitsandbytes_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
         )
+    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
     if is_gguf_quantized and not is_gguf_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
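To put the `torch.compile` documentation added to `bitsandbytes.md` at the top of this diff in context, here is a minimal sketch of how the compiled, quantized transformer plugs into a full pipeline for a rough benchmark. It is an illustration only: it assumes a CUDA GPU with enough memory for FLUX.1-dev, uses the public `BitsAndBytesConfig` export instead of the `DiffusersBitsAndBytesConfig` alias from the docs, and the prompt, step count, and timing loop are stand-ins for the linked benchmarking script.

```py
import time

import torch
from diffusers import AutoModel, BitsAndBytesConfig, FluxPipeline

torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Load the 4-bit quantized transformer and compile it, as in the docs above.
quant_config = BitsAndBytesConfig(load_in_4bit=True)
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
transformer.compile(fullgraph=True)

# Plug the compiled, quantized transformer into the full pipeline.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.float16
).to("cuda")

# The first call pays the compilation cost; time the second call instead.
pipe("a photo of a dog", num_inference_steps=28)
start = time.time()
pipe("a photo of a dog", num_inference_steps=28)
print(f"compiled 4-bit generation: {time.time() - start:.3f}s")
```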
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): weight_on_cpu = True device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" - if is_bnb_4bit_quantized: + if is_bnb_4bit_quantized or is_bnb_8bit_quantized: module_weight = dequantize_bnb_weight( module.weight.to(device) if weight_on_cpu else module.weight, - state=module.weight.quant_state, + state=module.weight.quant_state if is_bnb_4bit_quantized else module.state, dtype=model.dtype, ).data elif is_gguf_quantized: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c25e9997e3..09e3621c2c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed( theta = theta * ntk_factor freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor + 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] is_npu = freqs.device.type == "npu" diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c71a8b3b5a..beaea48050 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -816,14 +816,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Defaults to `None`, meaning that the model will be loaded on CPU. + Examples: + + ```py + >>> from diffusers import AutoModel + >>> import torch + + >>> # This works. + >>> model = AutoModel.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda" + ... ) + >>> # This also works (integer accelerator device ID). + >>> model = AutoModel.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0 + ... ) + >>> # Specifying a supported offloading strategy like "auto" also works. + >>> model = AutoModel.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto" + ... ) + >>> # Specifying a dictionary as `device_map` also works. + >>> model = AutoModel.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... subfolder="unet", + ... device_map={"": torch.device("cuda")}, + ... ) + ``` + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You + can also refer to the [Diffusers-specific + documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding) + for more concrete examples. 
max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. @@ -1389,7 +1418,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): low_cpu_mem_usage: bool = True, dtype: Optional[Union[str, torch.dtype]] = None, keep_in_fp32_modules: Optional[List[str]] = None, - device_map: Dict[str, Union[int, str, torch.device]] = None, + device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None, offload_state_dict: Optional[bool] = None, offload_folder: Optional[Union[str, os.PathLike]] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 6dbcd5c6db..b9d0b92561 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -898,6 +898,7 @@ class FluxPipeline( ) # 6. Denoising loop + self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index a67eb5d0d6..29c763a6b1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -1193,6 +1193,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0ac4251ec6..efeb085a72 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -669,14 +669,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn’t need to be defined for each - parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the - same device. - - Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + device_map (`str`, *optional*): + Strategy that dictates how the different components of a pipeline should be placed on available + devices. Currently, only "balanced" `device_map` is supported. Check out + [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement) + to know more. max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. 
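The two docstring updates above draw a line between model-level and pipeline-level `device_map` handling, so a short side-by-side sketch may help. It assumes a single CUDA GPU and reuses the SDXL checkpoint already referenced in the examples above.

```py
import torch
from diffusers import AutoModel, DiffusionPipeline

# Model-level loading: plain device identifiers (str, int, torch.device), "auto",
# or a dict are all accepted after this change.
unet = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
)

# Pipeline-level loading: only the "balanced" placement strategy is supported.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    device_map="balanced",
    torch_dtype=torch.float16,
)
```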
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 3c0ac30bb6..6df66118b0 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -388,8 +388,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). height (`int`, defaults to `480`): The height in pixels of the generated image. width (`int`, defaults to `832`): @@ -434,8 +436,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. Examples: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 77f0e4d56a..c71138a97d 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -562,12 +562,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, *optional*, defaults to `512`): - The maximum sequence length of the prompt. - shift (`float`, *optional*, defaults to `5.0`): - The shift of the flow. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + Examples: Returns: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index e029006aa1..a0b5ed93c9 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -687,8 +687,33 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + video (`List[PIL.Image.Image]`, *optional*): + The input video or videos to be used as a starting point for the generation. The video should be a list + of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one + video at a time. + mask (`List[PIL.Image.Image]`, *optional*): + The input mask defines which video regions to condition on and which to generate. Black areas in the + mask indicate conditioning regions, while white areas indicate regions for generation. The mask should + be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video + at a time. + reference_images (`List[PIL.Image.Image]`, *optional*): + A list of one or more reference images as extra conditioning for the generation. For example, if you + are trying to inpaint a video to change the character, you can pass reference images of the new + character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582) + and original [user + guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md) + for a full list of supported tasks and use cases. + conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`): + The conditioning scale to be applied when adding the control conditioning latent stream to the + denoising latent stream in each control layer of the model. If a float is provided, it will be applied + uniformly to all layers. If a list or tensor is provided, it should have the same length as the number + of control layers in the model (`len(transformer.config.vace_layers)`). height (`int`, defaults to `480`): The height in pixels of the generated image. width (`int`, defaults to `832`): @@ -733,8 +758,9 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. 
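Since the expanded docstring above describes the `video`, `mask`, `reference_images`, and `conditioning_scale` inputs but the `Examples:` block itself is not part of this hunk, here is a hedged usage sketch. The checkpoint id is a placeholder, the frame count and mask preprocessing are assumptions, and only parameters documented above are used.

```py
import torch
from PIL import Image
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Placeholder checkpoint id; substitute a real Wan VACE checkpoint.
pipe = WanVACEPipeline.from_pretrained("<wan-vace-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

num_frames = 49  # assumption; Wan models typically expect 4k + 1 frames
video = [Image.new("RGB", (832, 480)) for _ in range(num_frames)]           # conditioning video (black = keep)
mask = [Image.new("RGB", (832, 480), "white") for _ in range(num_frames)]   # white = regions to generate

output = pipe(
    prompt="a cat walking through a rainy garden",
    negative_prompt="blurry, low quality",
    video=video,
    mask=mask,
    conditioning_scale=1.0,
    height=480,
    width=832,
    max_sequence_length=512,
).frames[0]
export_to_video(output, "wan_vace.mp4", fps=16)
```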
Examples: diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 1844f1b49b..1a2d2e9c22 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ) if latents is None: - if isinstance(generator, list): - init_latents = [ - retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) - ] - else: - init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] init_latents = torch.cat(init_latents, dim=0).to(dtype) @@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): if hasattr(self.scheduler, "add_noise"): latents = self.scheduler.add_noise(init_latents, noise, timestep) else: - latents = self.scheduelr.scale_noise(init_latents, timestep, noise) + latents = self.scheduler.scale_noise(init_latents, timestep, noise) else: latents = latents.to(device) @@ -513,7 +508,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. height (`int`, defaults to `480`): The height in pixels of the generated image. @@ -530,6 +525,8 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + strength (`float`, defaults to `0.8`): + Higher strength leads to more differences between original image and generated video. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -559,8 +556,9 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. 
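For reviewers wondering why the generator plumbing could be dropped in the hunk above: with `sample_mode="argmax"` the initial latents come from the deterministic mode of the VAE's latent distribution rather than a random sample, so no per-item generator is needed. The shared `retrieve_latents` helper used by these pipelines looks roughly like this (paraphrased, not copied from this diff):

```py
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # "sample" draws from the latent distribution and therefore needs a generator
    # for reproducibility; "argmax" takes the distribution mode and is deterministic.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
```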
Examples: diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 609c9ad15a..871faf076e 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") - if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9(): + if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): raise ValueError( f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." @@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin): QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) - if cls._is_cuda_capability_atleast_8_9(): + if cls._is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) return QUANTIZATION_TYPES @@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin): ) @staticmethod - def _is_cuda_capability_atleast_8_9() -> bool: - if not torch.cuda.is_available(): - raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.") - - major, minor = torch.cuda.get_device_capability() - if major == 8: - return minor >= 9 - return major >= 9 + def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major == 8: + return minor >= 9 + return major >= 9 + elif torch.xpu.is_available(): + return True + else: + raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.") def get_apply_tensor_subclass(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af89..4878937ab2 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -154,12 +154,30 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None): """ Import a module on the cache directory for modules and extract a class from it. """ module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory + # separator. We do a bit of monkey patching to detect and fix this case. + if not ( + pretrained_model_name_or_path is not None + and "." 
in pretrained_model_name_or_path + and module_path.startswith("diffusers_modules") + and pretrained_model_name_or_path.replace("/", "--") in module_path + ): + raise e # We can't figure this one out, just reraise the original error + + corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py" + corrected_path = corrected_path.replace( + pretrained_model_name_or_path.replace("/", "--").replace(".", "/"), + pretrained_model_name_or_path.replace("/", "--"), + ) + module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module() if class_name is None: return find_pipeline_class(module) @@ -454,4 +472,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f7244e97b8..b64cecc412 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -99,6 +99,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA else: logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False + _torch_version = "N/A" _jax_version = "N/A" _flax_version = "N/A" diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e19a9f83fd..5cbe5ff277 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -291,6 +291,18 @@ def require_torch_version_greater_equal(torch_version): return decorator +def require_torch_version_greater(torch_version): + """Decorator marking a test that requires torch with a specific version greater.""" + + def decorator(test_case): + correct_torch_version = is_torch_available() and is_torch_version(">", torch_version) + return unittest.skipUnless( + correct_torch_version, f"test requires torch with the version greater than {torch_version}" + )(test_case) + + return decorator + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( @@ -300,9 +312,7 @@ def require_torch_gpu(test_case): def require_torch_cuda_compatibility(expected_compute_capability): def decorator(test_case): - if not torch.cuda.is_available(): - return unittest.skip(test_case) - else: + if torch.cuda.is_available(): current_compute_capability = get_torch_cuda_device_capability() return unittest.skipUnless( float(current_compute_capability) == float(expected_compute_capability), diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 77977a78d8..db87004fcb 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -21,6 +21,7 @@ import torch from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, load_image, slow, @@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test 
super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) @torch.no_grad() def test_encode_decode(self): diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 0e5fdc4bba..1a7959a877 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -22,6 +22,7 @@ import torch from diffusers import UNet2DModel from diffusers.utils import logging from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, require_torch_accelerator, @@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): # two models don't need to stay in the device at the same time del model_accelerate - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() model_normal_load, _ = UNet2DModel.from_pretrained( diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index ab0dcbc1de..e0331d15dd 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -978,13 +978,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) - @require_torch_gpu @parameterized.expand( [ ("hf-internal-testing/unet2d-sharded-dummy", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), ] ) + @require_torch_accelerator def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) @@ -994,13 +994,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @require_torch_gpu @parameterized.expand( [ ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), ] ) + @require_torch_accelerator def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) @@ -1084,6 +1084,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @parameterized.expand( + [ + (-1, "You can't pass device_map as a negative int"), + ("foo", "When passing device_map as a string, the value needs to be a device name"), + ] + ) + def test_wrong_device_map_raises_error(self, device_map, msg_substring): + with self.assertRaises(ValueError) as err_ctx: + _ = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map + ) + + assert msg_substring in str(err_ctx.exception) + + @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")]) + @require_torch_gpu + def test_passing_non_dict_device_map_works(self, device_map): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map + ) + output = loaded_model(**inputs_dict) + assert output.sample.shape 
== (4, 4, 16, 16) + + @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))]) + @require_torch_gpu + def test_passing_dict_device_map_works(self, name, device_map): + # There are other valid dict-based `device_map` values too. It's best to refer to + # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap. + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map} + ) + output = loaded_model(**inputs_dict) + assert output.sample.shape == (4, 4, 16, 16) + @require_peft_backend def test_load_attn_procs_raise_warning(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 30fdd68cfd..30a14ef7f5 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_hf_hub_version_greater, @@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_allegro(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py index aaf44985aa..340bc24adc 100644 --- a/tests/pipelines/audioldm/test_audioldm.py +++ b/tests/pipelines/audioldm/test_audioldm.py @@ -37,7 +37,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device +from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) @@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index a8f60fb6dc..14b5510fca 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ 
b/tests/pipelines/audioldm2/test_audioldm2.py @@ -45,7 +45,13 @@ from diffusers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + is_torch_version, + nightly, + torch_device, +) from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index a9de0ff05f..a6349c99c5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_cogvideox(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index 79dffd230a..4eca68dd7b 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_cogview3plus(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 100765ee34..0147d4a651 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo from diffusers.utils import load_image from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_numpy, @@ -412,12 +413,12 @@ class 
ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index b06590e13c..63d5fd4660 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo from diffusers.utils import load_image from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_numpy, @@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 1be15645ef..7880f744b9 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3ControlNetPipeline diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py index 295b29f12e..8445229079 100644 --- a/tests/pipelines/deepfloyd_if/test_if.py +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, load_numpy, @@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase): image = output.images[0] - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = backend_max_memory_allocated(torch_device) assert mem_bytes < 12 * 10**9 expected_image = load_numpy( diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py index da06dc3558..14271a9862 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, floats_tensor, @@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase): ) image = output.images[0] - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = 
backend_max_memory_allocated(torch_device) assert mem_bytes < 12 * 10**9 expected_image = load_numpy( diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 646ad928ec..cbdf617d71 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -224,7 +224,7 @@ class FluxPipelineFastTests( @nightly @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-schnell" @@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase): @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxIPAdapterPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-dev" diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py index 1f204add1c..b8f36dfd3c 100644 --- a/tests/pipelines/flux/test_pipeline_flux_redux.py +++ b/tests/pipelines/flux/test_pipeline_flux_redux.py @@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import ( @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxReduxSlowTests(unittest.TestCase): pipeline_class = FluxPriorReduxPipeline repo_id = "black-forest-labs/FLUX.1-Redux-dev" diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py index 66453b73b0..05c94262ab 100644 --- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py @@ -23,6 +23,7 @@ from transformers import AutoTokenizer, BertModel, T5EncoderModel from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_hunyuan_dit_1024(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index c66491b15c..8399e57bfb 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -27,6 +27,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, load_numpy, nightly, numpy_cosine_similarity_distance, @@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_attend_and_excite_fp16(self): generator = torch.manual_seed(51) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 
2feeaaf11c..ff4a33abf8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, enable_full_determinism, @@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): output_type="np", ) - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = backend_max_memory_allocated(torch_device) # make sure that less than 2.65 GB is allocated assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 8e2fa77fc0..577ac4ebdd 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3PipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Pipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index 80bb35a08e..f5b5e63a81 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index a41e7dc7f3..11f08c8820 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -35,6 +35,7 @@ from diffusers import ( UniPCMultistepScheduler, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, load_image, numpy_cosine_similarity_distance, @@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_stable_diffusion_lcm(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 9a141634a3..7d19d745a2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ 
b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -39,6 +39,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_image, @@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_stable_diffusion_xl_img2img_playground(self): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index f1d9d244e5..c4db662784 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1105,6 +1105,21 @@ class CustomPipelineTests(unittest.TestCase): assert images.shape == (1, 64, 64, 3) + def test_remote_custom_pipe_with_dot_in_name(self): + # make sure that trust remote code has to be passed + with self.assertRaises(ValueError): + pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name") + + pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True) + + assert pipeline.__class__.__name__ == "CustomPipeline" + + pipeline = pipeline.to(torch_device) + images, output_str = pipeline(num_inference_steps=2, output_type="np") + + assert images[0].shape == (1, 32, 32, 3) + assert output_str == "This is a test" + def test_local_custom_pipeline_repo(self): local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") pipeline = DiffusionPipeline.from_pretrained( @@ -1203,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def dummy_image(self): batch_size = 1 diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index a162e6841d..e3a153dd19 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, require_torch_accelerator, slow, + torch_device, ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) @unittest.skip("TODO: test needs to be implemented") def test_Wanx(self): diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index ac1b0cf3ce..2d8b9f698b 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -30,6 +30,7 @@ from diffusers import ( FluxTransformer2DModel, SD3Transformer2DModel, ) +from diffusers.quantizers import PipelineQuantizationConfig from diffusers.utils import is_accelerate_version, logging from 
diffusers.utils.testing_utils import ( CaptureLogger, @@ -44,11 +45,14 @@ from diffusers.utils.testing_utils import ( require_peft_backend, require_torch, require_torch_accelerator, + require_torch_version_greater, require_transformers_version_greater, slow, torch_device, ) +from ..test_torch_compile_utils import QuantCompileTests + def get_some_linear_layer(model): if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: @@ -855,3 +859,26 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests): def test_fp4_double_safe(self): self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) + + +@require_torch_version_greater("2.7.1") +class Bnb4BitCompileTests(QuantCompileTests): + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, + components_to_quantize=["transformer", "text_encoder_2"], + ) + + def test_torch_compile(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile(quantization_config=self.quantization_config) + + def test_torch_compile_with_cpu_offload(self): + super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) + + def test_torch_compile_with_group_offload(self): + super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 98575b86cd..b15a9f72a8 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -19,15 +19,18 @@ import unittest import numpy as np import pytest from huggingface_hub import hf_hub_download +from PIL import Image from diffusers import ( BitsAndBytesConfig, DiffusionPipeline, + FluxControlPipeline, FluxTransformer2DModel, SanaTransformer2DModel, SD3Transformer2DModel, logging, ) +from diffusers.quantizers import PipelineQuantizationConfig from diffusers.utils import is_accelerate_version from diffusers.utils.testing_utils import ( CaptureLogger, @@ -39,14 +42,18 @@ from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, + require_peft_backend, require_peft_version_greater, require_torch, require_torch_accelerator, + require_torch_version_greater_equal, require_transformers_version_greater, slow, torch_device, ) +from ..test_torch_compile_utils import QuantCompileTests + def get_some_linear_layer(model): if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: @@ -697,6 +704,50 @@ class SlowBnb8bitFluxTests(Base8bitTests): self.assertTrue(max_diff < 1e-3) +@require_transformers_version_greater("4.44.0") +@require_peft_backend +class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests): + def setUp(self) -> None: + gc.collect() + backend_empty_cache(torch_device) + + self.pipeline_8bit = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={"load_in_8bit": True}, + components_to_quantize=["transformer", "text_encoder_2"], + ), + torch_dtype=torch.float16, + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + backend_empty_cache(torch_device) + + def test_lora_loading(self): + 
self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + + output = self.pipeline_8bit( + prompt=self.prompt, + control_image=Image.new(mode="RGB", size=(256, 256)), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + num_inference_steps=8, + generator=torch.Generator().manual_seed(42), + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}") + + @slow class BaseBnb8bitSerializationTests(Base8bitTests): def setUp(self): @@ -773,3 +824,27 @@ class BaseBnb8bitSerializationTests(Base8bitTests): out_0 = self.model_0(**inputs)[0] out_1 = model_1(**inputs)[0] self.assertTrue(torch.equal(out_0, out_1)) + + +@require_torch_version_greater_equal("2.6.0") +class Bnb8BitCompileTests(QuantCompileTests): + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={"load_in_8bit": True}, + components_to_quantize=["transformer", "text_encoder_2"], + ) + + def test_torch_compile(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16) + + def test_torch_compile_with_cpu_offload(self): + super()._test_torch_compile_with_cpu_offload( + quantization_config=self.quantization_config, torch_dtype=torch.float16 + ) + + @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") + def test_torch_compile_with_group_offload(self): + super()._test_torch_compile_with_group_offload( + quantization_config=self.quantization_config, torch_dtype=torch.float16 + ) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py new file mode 100644 index 0000000000..1ae77b27d7 --- /dev/null +++ b/tests/quantization/test_torch_compile_utils.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gc +import unittest + +import torch + +from diffusers import DiffusionPipeline +from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device + + +@require_torch_gpu +@slow +class QuantCompileTests(unittest.TestCase): + quantization_config = None + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def _init_pipeline(self, quantization_config, torch_dtype): + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + quantization_config=quantization_config, + torch_dtype=torch_dtype, + ) + return pipe + + def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda") + # import to ensure fullgraph True + pipe.transformer.compile(fullgraph=True) + + for _ in range(2): + # small resolutions to ensure speedy execution. + pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) + + def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = self._init_pipeline(quantization_config, torch_dtype) + pipe.enable_model_cpu_offload() + pipe.transformer.compile() + + for _ in range(2): + # small resolutions to ensure speedy execution. + pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) + + def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): + torch._dynamo.config.cache_size_limit = 10000 + + pipe = self._init_pipeline(quantization_config, torch_dtype) + group_offload_kwargs = { + "onload_device": torch.device("cuda"), + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + } + pipe.transformer.enable_group_offload(**group_offload_kwargs) + pipe.transformer.compile() + for name, component in pipe.components.items(): + if name != "transformer" and isinstance(component, torch.nn.Module): + if torch.device(component.device).type == "cpu": + component.to("cuda") + + for _ in range(2): + # small resolutions to ensure speedy execution. 
+            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 0e671307dd..743da17356 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -30,13 +30,15 @@ from diffusers import (
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchao_version_greater_or_equal,
     slow,
     torch_device,
@@ -61,7 +63,7 @@ if is_torchao_available():


 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase):
         Test kwargs validations in TorchAoConfig
         """
         _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "is not supported"):
             _ = TorchAoConfig("uint8")

         with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +121,12 @@ class TorchAoConfigTest(unittest.TestCase):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -269,6 +271,7 @@ class TorchAoTest(unittest.TestCase):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=f"{torch_device}:0",
         )

         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +341,7 @@ class TorchAoTest(unittest.TestCase):

             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

         with tempfile.TemporaryDirectory() as offload_folder:
             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@ class TorchAoTest(unittest.TestCase):

             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -518,14 +521,14 @@ class TorchAoTest(unittest.TestCase):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@ class TorchAoSerializationTest(unittest.TestCase):
         )
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

-    def test_int_a8w8_cuda(self):
+    def test_int_a8w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

-    def test_int_a16w8_cuda(self):
+    def test_int_a16w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

@@ -624,14 +627,14 @@ class TorchAoSerializationTest(unittest.TestCase):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ class SlowTorchAoTests(unittest.TestCase):
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)

     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ class SlowTorchAoTests(unittest.TestCase):
             pipe.remove_all_hooks()
             del pipe.transformer
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)
             transformer = FluxTransformer2DModel.from_pretrained(
                 tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
             )
@@ -783,14 +786,14 @@ class SlowTorchAoTests(unittest.TestCase):


 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
index d3ffd4fc3a..2ac681897d 100644
--- a/tests/single_file/test_lumina2_transformer.py
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     Lumina2Transformer2DModel,
 )
@@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):

     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index bf11faaa9c..81779cf8fa 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     FluxTransformer2DModel,
 )
@@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):

     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index 802ca37abf..e74c5be6ff 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -1,8 +1,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     SanaTransformer2DModel,
 )
@@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):

     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)