diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 8b87db958d..3368db1ec0 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
-        validation_image = validation_image.resize((args.resolution, args.resolution))
+
+        try:
+            interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
+        except (AttributeError, KeyError):
+            supported_interpolation_modes = [
+                f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+            ]
+            raise ValueError(
+                f"Interpolation mode {args.image_interpolation_mode} is not supported. "
+                f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
+            )
+
+        transform = transforms.Compose(
+            [
+                transforms.Resize(args.resolution, interpolation=interpolation),
+                transforms.CenterCrop(args.resolution),
+            ]
+        )
+        validation_image = transform(validation_image)
 
         images = []
 
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
             " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
         ),
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
 
 
 def prepare_train_dataset(dataset, accelerator):
+    try:
+        interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
+    except (AttributeError, KeyError):
+        supported_interpolation_modes = [
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ]
+        raise ValueError(
+            f"Interpolation mode {args.image_interpolation_mode} is not supported. "
+            f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
+        )
+
     image_transforms = transforms.Compose(
         [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.Resize(args.resolution, interpolation=interpolation_mode),
             transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
 
     conditioning_image_transforms = transforms.Compose(
         [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.Resize(args.resolution, interpolation=interpolation_mode),
             transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
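The training-script change above routes every resize through a user-selectable torchvision interpolation mode instead of PIL's default resize. A minimal sketch of the same lookup outside the script, assuming torchvision is installed; the resolution and file name are placeholders, not values taken from the diff:

from PIL import Image
from torchvision import transforms

resolution = 1024  # placeholder, stands in for args.resolution
image_interpolation_mode = "lanczos"  # the new argparse default

# Same lookup the patch performs: the lowercase CLI string is upper-cased and
# resolved against torchvision's InterpolationMode enum (LANCZOS, BILINEAR, ...).
interpolation = getattr(transforms.InterpolationMode, image_interpolation_mode.upper())

transform = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=interpolation),
        transforms.CenterCrop(resolution),
    ]
)
validation_image = transform(Image.open("conditioning.png").convert("RGB"))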
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index d88114436d..a2c2e2430c 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -55,7 +55,7 @@ class ModuleGroup:
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
@@ -115,8 +115,13 @@ class ModuleGroup:
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
@@ -162,9 +167,15 @@ class ModuleGroup:
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
     if use_stream:
         if torch.cuda.is_available():
             stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -592,7 +605,7 @@
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
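The offloading hook now resolves the stream API of whichever accelerator is active instead of hard-coding torch.cuda. A standalone sketch of that device-agnostic pattern, assuming a PyTorch build that exposes torch.accelerator and a machine with either a CUDA or an Intel XPU device; older builds fall back to torch.cuda, just as the patch does:

import torch

# Resolve the torch submodule (torch.cuda, torch.xpu, ...) that owns the stream
# and synchronization APIs for the current accelerator; fall back to torch.cuda
# when torch.accelerator is unavailable, mirroring the hook above.
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)

# Stream creation mirrors apply_group_offloading(..., use_stream=True): CUDA keeps
# its dedicated Stream class, while XPU uses the device-generic torch.Stream.
if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    transfer_stream = torch.Stream()
else:
    raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

with torch_accelerator_module.stream(transfer_stream):
    pass  # enqueue asynchronous host -> device copies here
torch_accelerator_module.current_stream().synchronize()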
diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
index 7efb390287..9622850766 100644
--- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -22,6 +22,7 @@ from parameterized import parameterized
 from diffusers import AsymmetricAutoencoderKL
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
@@ -134,18 +135,32 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
             # fmt: off
             [
                 33,
-                [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
-                [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
+                Expectations(
+                    {
+                        ("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
+                        ("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
+                        ("mps", None): torch.tensor(
+                            [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
+                        ),
+                    }
+                ),
             ],
             [
                 47,
-                [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
-                [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
+                Expectations(
+                    {
+                        ("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+                        ("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+                        ("mps", None): torch.tensor(
+                            [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
+                        ),
+                    }
+                ),
             ],
             # fmt: on
         ]
     )
-    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
+    def test_stable_diffusion(self, seed, expected_slices):
         model = self.get_sd_vae_model()
         image = self.get_sd_image(seed)
         generator = self.get_generator(seed)
@@ -156,9 +171,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape
 
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
 
-        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+        expected_slice = expected_slices.get_expectation()
+        assert torch_all_close(output_slice, expected_slice, atol=5e-3)
 
     @parameterized.expand(
         [
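The test now stores one expected slice per backend and lets the helper pick the matching one at runtime instead of branching on torch_device == "mps". Based only on the usage visible in this diff, a sketch of how such a table is keyed and queried; the tensor values below are placeholders, and the meaning of the second key element (a hardware generation such as a CUDA compute-capability major, with None appearing to leave it unconstrained) is inferred from the entries above:

import torch
from diffusers.utils.testing_utils import Expectations

# Keys are (device_type, generation) pairs, mirroring the ("xpu", 3), ("cuda", 7),
# and ("mps", None) entries in the diff. Values here are placeholder tensors.
expected_slices = Expectations(
    {
        ("xpu", 3): torch.tensor([0.1, 0.2, 0.3, 0.4]),
        ("cuda", 7): torch.tensor([0.1, 0.2, 0.3, 0.4]),
        ("mps", None): torch.tensor([0.5, 0.6, 0.7, 0.8]),
    }
)

# get_expectation() resolves the entry matching the backend the test runs on,
# replacing the old manual torch_device == "mps" branch.
expected_slice = expected_slices.get_expectation()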
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 9ce62cde9f..a64b3c66ea 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     torch_device,
 )
 from diffusers.utils.torch_utils import randn_tensor
@@ -210,8 +210,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
 
 
 @nightly
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class FluxControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxControlNetPipeline
 
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index 36f0919cac..72b4b3a58a 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -24,7 +24,7 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_torch_accelerator,
     torch_device,
 )
@@ -62,7 +62,7 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
 )
 
 
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_torch_accelerator
 class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
     model_class = WanTransformer3DModel
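Finally, the two test modules above swap the CUDA-only gate for the device-agnostic one. A sketch of how a test class would opt in under the renamed helpers; the class and test names are hypothetical, and the skip condition (a large-memory CUDA or XPU accelerator being present) is inferred from the decorator's name rather than spelled out in this diff:

import unittest

import pytest

from diffusers.utils.testing_utils import require_big_accelerator


# The pytest marker lets these tests be selected with `-m big_accelerator`,
# while the decorator is expected to skip the class when no suitable
# large-memory accelerator is available.
@pytest.mark.big_accelerator
@require_big_accelerator
class MyPipelineSlowTests(unittest.TestCase):  # hypothetical test class
    def test_placeholder(self):
        self.assertTrue(True)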