mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] fix: lora loading when using with a device_mapped model. (#9449)

* fix: lora loading when using with a device_mapped model.

* better attributing

* empty

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* minors

* better error messages.

* fix-copies

* add: tests, docs.

* add hardware note.

* quality

* Update docs/source/en/training/distributed_inference.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fixes

* skip properly.

* fixes

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2024-10-31 21:17:41 +05:30
committed by GitHub
parent ff182ad669
commit 41e4779d98
22 changed files with 546 additions and 8 deletions

View File

@@ -237,3 +237,5 @@ with torch.no_grad():
```
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow.
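For reference, a condensed sketch of that combination, adapted from the `test_flux_component_sharding_with_lora` test added later in this commit (the checkpoint and LoRA repository names are the ones used in that test; running it requires multiple GPUs with enough memory):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Shard only the transformer across the available GPUs.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", device_map="auto", torch_dtype=torch.bfloat16
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

# Works because this LoRA only targets the transformer (no text encoder components).
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
```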

View File

@@ -31,6 +31,7 @@ from ..utils import (
    delete_adapter_layers,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    is_peft_available,
    is_transformers_available,
    logging,
@@ -214,9 +215,18 @@ class LoraBaseMixin:
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        def model_has_device_map(model):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False
            return getattr(model, "hf_device_map", None) is not None

        if _pipeline is not None and _pipeline.hf_device_map is None:
            for _, component in _pipeline.components.items():
-               if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if (
                    isinstance(component, nn.Module)
                    and hasattr(component, "_hf_hook")
                    and not model_has_device_map(component)
                ):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:

View File

@@ -39,6 +39,7 @@ from ..utils import (
    get_adapter_name,
    get_peft_kwargs,
    is_accelerate_available,
    is_accelerate_version,
    is_peft_version,
    is_torch_version,
    logging,
@@ -398,9 +399,18 @@ class UNet2DConditionLoadersMixin:
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        def model_has_device_map(model):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False
            return getattr(model, "hf_device_map", None) is not None

        if _pipeline is not None and _pipeline.hf_device_map is None:
            for _, component in _pipeline.components.items():
-               if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if (
                    isinstance(component, nn.Module)
                    and hasattr(component, "_hf_hook")
                    and not model_has_device_map(component)
                ):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:

View File

@@ -36,6 +36,7 @@ from ..utils import (
    deprecate,
    get_class_from_dynamic_module,
    is_accelerate_available,
    is_accelerate_version,
    is_peft_available,
    is_transformers_available,
    logging,
@@ -947,3 +948,9 @@ def _get_ignore_patterns(
        )

    return ignore_patterns


def model_has_device_map(model):
    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
        return False
    return getattr(model, "hf_device_map", None) is not None
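As of this commit, the helper can be exercised directly; a small sketch (not part of the diff) using the sharded Flux transformer from the tests below — `hf_device_map` is only populated by accelerate when a model is loaded with a `device_map`:

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers.pipelines.pipeline_loading_utils import model_has_device_map

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", device_map="auto", torch_dtype=torch.bfloat16
)
print(model_has_device_map(transformer))  # True: accelerate filled in `hf_device_map`
```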

View File

@@ -85,6 +85,7 @@ from .pipeline_loading_utils import (
    _update_init_kwargs_with_connected_pipeline,
    load_sub_model,
    maybe_raise_or_warn,
    model_has_device_map,
    variant_compatible_siblings,
    warn_deprecated_model_variant,
)
@@ -406,6 +407,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

        # device-mapped modules should not go through any device placements.
        device_mapped_components = [
            key for key, component in self.components.items() if model_has_device_map(component)
        ]
        if device_mapped_components:
            raise ValueError(
                "The following pipeline components have been found to use a device map: "
                f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
            )

        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
@@ -1002,6 +1013,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".
        """
        # device-mapped modules should not go through any device placements.
        device_mapped_components = [
            key for key, component in self.components.items() if model_has_device_map(component)
        ]
        if device_mapped_components:
            raise ValueError(
                "The following pipeline components have been found to use a device map: "
                f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
            )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
@@ -1104,6 +1125,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".
        """
        # device-mapped modules should not go through any device placements.
        device_mapped_components = [
            key for key, component in self.components.items() if model_has_device_map(component)
        ]
        if device_mapped_components:
            raise ValueError(
                "The following pipeline components have been found to use a device map: "
                f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
            )

        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
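Taken together, the three guards fail fast whenever a device-mapped pipeline is asked to move devices or offload. A minimal sketch of the resulting behaviour (checkpoint name reused from the tests in this commit; it assumes at least two GPUs so `device_map="balanced"` actually shards the components):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", device_map="balanced", torch_dtype=torch.bfloat16
)

# Each call is now rejected up front instead of silently fighting the device map.
for call in (lambda: pipe.to("cuda"), pipe.enable_model_cpu_offload, pipe.enable_sequential_cpu_offload):
    try:
        call()
    except ValueError as err:
        print(err)  # "The following pipeline components have been found to use a device map: ..."
```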

View File

@@ -506,9 +506,14 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

    @unittest.skip("Test currently not supported.")
    def test_sequential_cpu_offload_forward_pass(self):
        pass

    @unittest.skip("Test currently not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass


@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):

View File

@@ -514,6 +514,18 @@ class StableDiffusionMultiControlNetPipelineFastTests(
        assert image.shape == (4, 64, 64, 3)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class StableDiffusionMultiControlNetOneModelPipelineFastTests(
    IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -697,6 +709,18 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
        except NotImplementedError:
            pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@slow
@require_torch_gpu

View File

@@ -389,6 +389,18 @@ class StableDiffusionMultiControlNetPipelineFastTests(
        except NotImplementedError:
            pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@slow
@require_torch_gpu

View File

@@ -441,6 +441,18 @@ class MultiControlNetInpaintPipelineFastTests(
        except NotImplementedError:
            pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@slow
@require_torch_gpu

View File

@@ -683,6 +683,18 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
    def test_save_load_optional_components(self):
        return self._test_save_load_optional_components()

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
    PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
@@ -887,6 +899,18 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
        self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@slow
@require_torch_gpu

View File

@@ -8,9 +8,11 @@ from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import (
    numpy_cosine_similarity_distance,
    require_big_gpu_with_torch_cuda,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
@@ -282,3 +284,172 @@ class FluxPipelineSlowTests(unittest.TestCase):
        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert max_diff < 1e-4

    @require_torch_multi_gpu
    @torch.no_grad()
    def test_flux_component_sharding(self):
        """
        internal note: test was run on `audace`.
        """
        ckpt_id = "black-forest-labs/FLUX.1-dev"
        dtype = torch.bfloat16
        prompt = "a photo of a cat with tiger-like look"

        pipeline = FluxPipeline.from_pretrained(
            ckpt_id,
            transformer=None,
            vae=None,
            device_map="balanced",
            max_memory={0: "16GB", 1: "16GB"},
            torch_dtype=dtype,
        )
        prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )

        del pipeline.text_encoder
        del pipeline.text_encoder_2
        del pipeline.tokenizer
        del pipeline.tokenizer_2
        del pipeline

        gc.collect()
        torch.cuda.empty_cache()

        transformer = FluxTransformer2DModel.from_pretrained(
            ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
        )
        pipeline = FluxPipeline.from_pretrained(
            ckpt_id,
            text_encoder=None,
            text_encoder_2=None,
            tokenizer=None,
            tokenizer_2=None,
            vae=None,
            transformer=transformer,
            torch_dtype=dtype,
        )

        height, width = 768, 1360
        # No need to wrap it up under `torch.no_grad()` as pipeline call method
        # is already wrapped under that.
        latents = pipeline(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=10,
            guidance_scale=3.5,
            height=height,
            width=width,
            output_type="latent",
            generator=torch.manual_seed(0),
        ).images
        latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy()
        expected_slice = np.array([-0.377, -0.3008, -0.5117, -0.252, 0.0615, -0.3477, -0.1309, -0.1914, 0.1533])

        assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4

        del pipeline.transformer
        del pipeline

        gc.collect()
        torch.cuda.empty_cache()

        vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device)
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="np")
        image_slice = image[0, :3, :3, -1].flatten()
        expected_slice = np.array([0.127, 0.1113, 0.1055, 0.1172, 0.1172, 0.1074, 0.1191, 0.1191, 0.1152])

        assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4

    @require_torch_multi_gpu
    @torch.no_grad()
    def test_flux_component_sharding_with_lora(self):
        """
        internal note: test was run on `audace`.
        """
        ckpt_id = "black-forest-labs/FLUX.1-dev"
        dtype = torch.bfloat16
        prompt = "jon snow eating pizza."

        pipeline = FluxPipeline.from_pretrained(
            ckpt_id,
            transformer=None,
            vae=None,
            device_map="balanced",
            max_memory={0: "16GB", 1: "16GB"},
            torch_dtype=dtype,
        )
        prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=512
        )

        del pipeline.text_encoder
        del pipeline.text_encoder_2
        del pipeline.tokenizer
        del pipeline.tokenizer_2
        del pipeline

        gc.collect()
        torch.cuda.empty_cache()

        transformer = FluxTransformer2DModel.from_pretrained(
            ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
        )
        pipeline = FluxPipeline.from_pretrained(
            ckpt_id,
            text_encoder=None,
            text_encoder_2=None,
            tokenizer=None,
            tokenizer_2=None,
            vae=None,
            transformer=transformer,
            torch_dtype=dtype,
        )
        pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")

        height, width = 768, 1360
        # No need to wrap it up under `torch.no_grad()` as pipeline call method
        # is already wrapped under that.
        latents = pipeline(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=10,
            guidance_scale=3.5,
            height=height,
            width=width,
            output_type="latent",
            generator=torch.manual_seed(0),
        ).images
        latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy()
        expected_slice = np.array([-0.6523, -0.4961, -0.9141, -0.5, -0.2129, -0.6914, -0.375, -0.5664, -0.1699])

        assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4

        del pipeline.transformer
        del pipeline

        gc.collect()
        torch.cuda.empty_cache()

        vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device)
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="np")
        image_slice = image[0, :3, :3, -1].flatten()
        expected_slice = np.array([0.1211, 0.1094, 0.1035, 0.1094, 0.1113, 0.1074, 0.1133, 0.1133, 0.1094])

        assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4

View File

@@ -139,6 +139,18 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
    def test_dict_tuple_outputs_equivalent(self):
        super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = KandinskyImg2ImgCombinedPipeline
@@ -248,6 +260,18 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
    def test_save_load_optional_components(self):
        super().test_save_load_optional_components(expected_max_difference=5e-4)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = KandinskyInpaintCombinedPipeline
@@ -363,3 +387,15 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
    def test_save_load_local(self):
        super().test_save_load_local(expected_max_difference=5e-3)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass

View File

@@ -159,6 +159,18 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
    def test_callback_cfg(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -281,6 +293,18 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
    def test_callback_cfg(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -404,3 +428,15 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
    def test_callback_cfg(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass

View File

@@ -404,6 +404,10 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

    @unittest.skip("Test currently not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass


@nightly
@require_torch_gpu

View File

@@ -279,3 +279,15 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
        )
        assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass

View File

@@ -593,6 +593,18 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
        if test_mean_pixel_difference:
            assert_mean_pixel_difference(output_batch[0][0], output[0][0])

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@slow
@require_torch_gpu

View File

@@ -642,9 +642,6 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
-       debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
-       print(",".join(debug))
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_adapter_sdxl_lcm_custom_timesteps(self):
@@ -667,7 +664,16 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
-       debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
-       print(",".join(debug))
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass

View File

@@ -184,6 +184,18 @@ class StableUnCLIPPipelineFastTests(
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=1e-3)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@nightly
@require_torch_gpu

View File

@@ -205,6 +205,18 @@ class StableUnCLIPImg2ImgPipelineFastTests(
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False)

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass


@nightly
@require_torch_gpu

View File

@@ -41,8 +41,11 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import (
    CaptureLogger,
    nightly,
    require_torch,
    require_torch_multi_gpu,
    skip_mps,
    slow,
    torch_device,
)
@@ -59,6 +62,10 @@ from ..models.unets.test_models_unet_2d_condition import (
from ..others.test_utils import TOKEN, USER, is_staging_test


if is_accelerate_available():
    from accelerate.utils import compute_module_sizes


def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()
@@ -1908,6 +1915,78 @@ class PipelineTesterMixin:
            )
        )

    @require_torch_multi_gpu
    @slow
    @nightly
    def test_calling_to_raises_error_device_mapped_components(self, safe_serialization=True):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        max_model_size = max(
            compute_module_sizes(module)[""]
            for _, module in pipe.components.items()
            if isinstance(module, torch.nn.Module)
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
            max_memory = {0: max_model_size, 1: max_model_size}
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

            with self.assertRaises(ValueError) as err_context:
                loaded_pipe.to(torch_device)

        self.assertTrue(
            "The following pipeline components have been found" in str(err_context.exception)
            and "This is incompatible with explicitly setting the device using `to()`" in str(err_context.exception)
        )

    @require_torch_multi_gpu
    @slow
    @nightly
    def test_calling_mco_raises_error_device_mapped_components(self, safe_serialization=True):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        max_model_size = max(
            compute_module_sizes(module)[""]
            for _, module in pipe.components.items()
            if isinstance(module, torch.nn.Module)
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
            max_memory = {0: max_model_size, 1: max_model_size}
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

            with self.assertRaises(ValueError) as err_context:
                loaded_pipe.enable_model_cpu_offload()

        self.assertTrue(
            "The following pipeline components have been found" in str(err_context.exception)
            and "This is incompatible with `enable_model_cpu_offload()`" in str(err_context.exception)
        )

    @require_torch_multi_gpu
    @slow
    @nightly
    def test_calling_sco_raises_error_device_mapped_components(self, safe_serialization=True):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        max_model_size = max(
            compute_module_sizes(module)[""]
            for _, module in pipe.components.items()
            if isinstance(module, torch.nn.Module)
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
            max_memory = {0: max_model_size, 1: max_model_size}
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

            with self.assertRaises(ValueError) as err_context:
                loaded_pipe.enable_sequential_cpu_offload()

        self.assertTrue(
            "The following pipeline components have been found" in str(err_context.exception)
            and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception)
        )
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):

View File

@@ -576,6 +576,15 @@ class UniDiffuserPipelineFastTests(
        expected_text_prefix = '" This This'
        assert text[0][: len(expected_text_prefix)] == expected_text_prefix

    def test_calling_mco_raises_error_device_mapped_components(self):
        super().test_calling_mco_raises_error_device_mapped_components(safe_serialization=False)

    def test_calling_to_raises_error_device_mapped_components(self):
        super().test_calling_to_raises_error_device_mapped_components(safe_serialization=False)

    def test_calling_sco_raises_error_device_mapped_components(self):
        super().test_calling_sco_raises_error_device_mapped_components(safe_serialization=False)
@nightly
@require_torch_gpu

View File

@@ -237,3 +237,15 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
    def test_callback_cfg(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_to_raises_error_device_mapped_components(self):
        pass

    @unittest.skip("Test not supported.")
    def test_calling_sco_raises_error_device_mapped_components(self):
        pass