diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index b2733b7c0a..90ff834ed9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1144,20 +1144,24 @@ class PipelineTesterMixin: self.assertLess( max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results" ) - offloaded_modules = [ - v + offloaded_modules = { + k: v for k, v in pipe.components.items() if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload - ] - ( - self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)), - f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}", + } + self.assertTrue( + all(v.device.type == "cpu" for v in offloaded_modules.values()), + f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}", ) - offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")] - ( - self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)), - f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}", + offloaded_modules_with_incorrect_hooks = {} + for k, v in offloaded_modules.items(): + if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) + + self.assertTrue( + len(offloaded_modules_with_incorrect_hooks) == 0, + f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) @unittest.skipIf( @@ -1189,22 +1193,23 @@ class PipelineTesterMixin: self.assertLess( max_diff, expected_max_diff, "running sequential offloading second time should have the inference results" ) - offloaded_modules = [ - v + offloaded_modules = { + k: v for k, v in pipe.components.items() if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload - ] - ( - self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)), - f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}", + } + self.assertTrue( + all(v.device.type == "meta" for v in offloaded_modules.values()), + f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}", ) + offloaded_modules_with_incorrect_hooks = {} + for k, v in offloaded_modules.items(): + if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook): + offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook) - offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")] - ( - self.assertTrue( - all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks) - ), - f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}", + self.assertTrue( + len(offloaded_modules_with_incorrect_hooks) == 0, + f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", ) @unittest.skipIf(