diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 56e8254c8c..5a8227f5ce 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -143,9 +143,9 @@ class ModularPipelineTesterMixin: def _check_for_parameters(parameters, expected_parameters, param_type): remaining_parameters = {param for param in parameters if param not in expected_parameters} - assert len(remaining_parameters) == 0, ( - f"Required {param_type} parameters not present: {remaining_parameters}" - ) + assert ( + len(remaining_parameters) == 0 + ), f"Required {param_type} parameters not present: {remaining_parameters}" _check_for_parameters(self.params, input_parameters, "input") _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate") @@ -274,9 +274,9 @@ class ModularPipelineTesterMixin: model_devices = [ component.device.type for component in pipe.components.values() if hasattr(component, "device") ] - assert all(device == torch_device for device in model_devices), ( - "All pipeline components are not on accelerator device" - ) + assert all( + device == torch_device for device in model_devices + ), "All pipeline components are not on accelerator device" def test_inference_is_not_nan_cpu(self): pipe = self.get_pipeline() @@ -318,3 +318,13 @@ class ModularPipelineTesterMixin: images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images") assert images.shape[0] == batch_size * num_images_per_prompt + + @require_accelerator + def test_components_auto_cpu_offload(self): + base_pipe = self.get_pipeline().to(torch_device) + for component in base_pipe.components: + assert component.device == torch_device + + cm = ComponentsManager() + cm.enable_auto_cpu_offload(device=torch_device) + offload_pipe = self.get_pipeline(components_manager=cm)