From 2340ed629ebad4a0e7ed5c2aaee975c8279432f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Sep 2023 09:48:35 +0200 Subject: [PATCH] [Test] Reduce CPU memory (#4897) * [Test] Reduce CPU memory * [Test] Reduce CPU memory --- tests/models/test_lora_layers.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index f5aa95b657..c49ea7f2d9 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -107,7 +107,7 @@ def state_dicts_almost_equal(sd1, sd2): models_are_equal = True for ten1, ten2 in zip(sd1.values(), sd2.values()): - if (ten1 - ten2).abs().sum() > 1e-3: + if (ten1 - ten2).abs().max() > 1e-3: models_are_equal = False return models_are_equal @@ -1432,23 +1432,21 @@ class LoraIntegrationTests(unittest.TestCase): self.assertTrue(np.allclose(images, expected, atol=1e-3)) def test_sdxl_1_0_fuse_unfuse_all(self): - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) unet_sd = copy.deepcopy(pipe.unet.state_dict()) - pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors") + pipe.load_lora_weights( + "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16 + ) pipe.fuse_lora() pipe.unload_lora_weights() pipe.unfuse_lora() - new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) - new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) - new_unet_sd = copy.deepcopy(pipe.unet.state_dict()) - - assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd) - assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd) - assert state_dicts_almost_equal(unet_sd, new_unet_sd) + assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) + assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) + assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): generator = torch.Generator().manual_seed(0)