1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Test] Reduce CPU memory (#4897)

* [Test] Reduce CPU memory

* [Test] Reduce CPU memory
This commit is contained in:
Patrick von Platen
2023-09-05 09:48:35 +02:00
committed by GitHub
parent cfdfcf2018
commit 2340ed629e

View File

@@ -107,7 +107,7 @@ def state_dicts_almost_equal(sd1, sd2):
models_are_equal = True
for ten1, ten2 in zip(sd1.values(), sd2.values()):
-        if (ten1 - ten2).abs().sum() > 1e-3:
+        if (ten1 - ten2).abs().max() > 1e-3:
models_are_equal = False
return models_are_equal
@@ -1432,23 +1432,21 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_sdxl_1_0_fuse_unfuse_all(self):
-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
unet_sd = copy.deepcopy(pipe.unet.state_dict())
-        pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
+        pipe.load_lora_weights(
+            "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16
+        )
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe.unfuse_lora()
-        new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
-        new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
-        new_unet_sd = copy.deepcopy(pipe.unet.state_dict())
-        assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
-        assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
-        assert state_dicts_almost_equal(unet_sd, new_unet_sd)
+        assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
+        assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
+        assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
generator = torch.Generator().manual_seed(0)