mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -143,9 +143,9 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert len(remaining_parameters) == 0, (
|
||||
f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
)
|
||||
assert (
|
||||
len(remaining_parameters) == 0
|
||||
), f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
|
||||
@@ -274,9 +274,9 @@ class ModularPipelineTesterMixin:
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == torch_device for device in model_devices), (
|
||||
"All pipeline components are not on accelerator device"
|
||||
)
|
||||
assert all(
|
||||
device == torch_device for device in model_devices
|
||||
), "All pipeline components are not on accelerator device"
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
|
||||
pipe = self.get_pipeline()
|
||||
@@ -318,3 +318,13 @@ class ModularPipelineTesterMixin:
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
@require_accelerator
|
||||
def test_components_auto_cpu_offload(self):
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
for component in base_pipe.components:
|
||||
assert component.device == torch_device
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
offload_pipe = self.get_pipeline(components_manager=cm)
|
||||
|
||||
Reference in New Issue
Block a user