mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into xfail-failing-tests-pipeline

Sayak Paul
2025-05-01 11:26:13 +08:00
committed by GitHub
5 changed files with 91 additions and 25 deletions

View File

@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_image = validation_image.resize((args.resolution, args.resolution))
try:
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
transform = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution),
]
)
validation_image = transform(validation_image)
images = []
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
def prepare_train_dataset(dataset, accelerator):
try:
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
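
For reference, the lookup pattern introduced in this file can be read as a small standalone helper. The sketch below is illustrative only (the helper name resolve_interpolation is hypothetical); it mirrors the getattr-based resolution and the error message from the hunks above, assuming torchvision is installed:

from torchvision import transforms

def resolve_interpolation(name: str) -> transforms.InterpolationMode:
    # Enum members are upper-case (e.g. LANCZOS), so normalize the CLI value.
    try:
        return getattr(transforms.InterpolationMode, name.upper())
    except AttributeError:
        supported = [
            f.lower()
            for f in dir(transforms.InterpolationMode)
            if not f.startswith("__") and not f.endswith("__")
        ]
        raise ValueError(
            f"Interpolation mode {name} is not supported. "
            f"Please select one of the following: {', '.join(supported)}"
        )

With the new flag, e.g. --image_interpolation_mode bicubic resolves to InterpolationMode.BICUBIC, while the default stays lanczos, so existing training invocations are unaffected.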

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
@@ -55,7 +55,7 @@ class ModuleGroup:
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
@@ -115,8 +115,13 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
@@ -162,9 +167,15 @@ class ModuleGroup:
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch.cuda.current_stream().synchronize()
torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
if use_stream:
if torch.cuda.is_available():
stream = torch.cuda.Stream()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
stream = torch.Stream()
else:
raise ValueError("Using streams for data transfer requires a CUDA device.")
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
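
The recurring pattern in this file is to resolve a device-specific stream API once and reuse it. A minimal standalone sketch of that dispatch, assuming PyTorch 2.6+ (which exposes torch.accelerator) and that an accelerator is actually present — current_accelerator() can return None otherwise:

import torch

def get_torch_accelerator_module():
    # Prefer the device-agnostic torch.accelerator API; on older PyTorch
    # builds fall back to torch.cuda.
    if hasattr(torch, "accelerator"):
        # current_accelerator().type is e.g. "cuda" or "xpu", so this
        # returns the matching torch.cuda / torch.xpu submodule.
        return getattr(torch, torch.accelerator.current_accelerator().type)
    return torch.cuda

accel = get_torch_accelerator_module()
stream = accel.Stream()               # device-agnostic transfer stream
accel.current_stream().synchronize()  # same call works on CUDA and XPU

This is also why the stream parameters above widen from Optional[torch.cuda.Stream] to Union[torch.cuda.Stream, torch.Stream, None]: on XPU the stream is created via the generic torch.Stream rather than a CUDA-specific type.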

View File

@@ -22,6 +22,7 @@ from parameterized import parameterized
from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -134,18 +135,32 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: off
[
33,
[-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
Expectations(
{
("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
("mps", None): torch.tensor(
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
),
}
),
],
[
47,
[0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
Expectations(
{
("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
("mps", None): torch.tensor(
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
),
}
),
],
# fmt: on
]
)
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
def test_stable_diffusion(self, seed, expected_slices):
model = self.get_sd_vae_model()
image = self.get_sd_image(seed)
generator = self.get_generator(seed)
@@ -156,9 +171,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
expected_slice = expected_slices.get_expectation()
assert torch_all_close(output_slice, expected_slice, atol=5e-3)
@parameterized.expand(
[
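
The Expectations table keys each slice by device type plus, apparently, a device generation (e.g. ("cuda", 7) for compute capability 7.x), with None acting as a wildcard. A hedged sketch of how such a lookup can resolve, purely illustrative — the real helper is Expectations.get_expectation in diffusers.utils.testing_utils and may differ in detail:

import torch

def get_expectation(table, device_type, major=None):
    # Prefer an exact (device, generation) match, then a wildcard entry.
    if (device_type, major) in table:
        return table[(device_type, major)]
    return table[(device_type, None)]

table = {
    ("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087]),
    ("mps", None): torch.tensor([-0.1603, 0.9878, -0.0495, -0.0790]),
}
assert torch.equal(get_expectation(table, "cuda", 7), table[("cuda", 7)])

This replaces the old mps-versus-everything branch in the test body with a single per-device table, which is what lets the xpu entries slot in without adding another keyword argument to the test signature.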

View File

@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_big_accelerator,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -210,8 +210,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
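
Renaming the pytest marker alongside the decorator keeps test selection device-agnostic: CI lanes that used to collect with -m big_gpu_with_torch_cuda can now pick up the same suites on any large accelerator. A hedged usage sketch (the marker name is taken from the decorator above; the exact CI invocation is an assumption):

import pytest

@pytest.mark.big_accelerator
class TestHeavySuite:
    def test_smoke(self):
        # Collected with: pytest -m big_accelerator
        assert True

The marker should also be registered in the pytest configuration (pytest.ini or pyproject.toml) to avoid unknown-marker warnings.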

View File

@@ -24,7 +24,7 @@ from diffusers import (
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_big_gpu_with_torch_cuda,
require_big_accelerator,
require_torch_accelerator,
torch_device,
)
@@ -62,7 +62,7 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
)
@require_big_gpu_with_torch_cuda
@require_big_accelerator
@require_torch_accelerator
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
model_class = WanTransformer3DModel
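
The same swap here drops the CUDA-only requirement from the Wan single-file tests. In spirit, an accelerator-gating decorator can be written as a thin unittest skip wrapper; the sketch below is illustrative, not the real require_big_accelerator from diffusers.utils.testing_utils (which may additionally check available device memory):

import unittest
import torch

def require_accelerator(test_case):
    # Skip unless some supported accelerator backend is available.
    available = (
        torch.cuda.is_available()
        or (hasattr(torch, "xpu") and torch.xpu.is_available())
        or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    )
    return unittest.skipUnless(available, "test requires a hardware accelerator")(test_case)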