mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Test] Reduce CPU memory (#4897)
* [Test] Reduce CPU memory * [Test] Reduce CPU memory
This commit is contained in:
committed by
GitHub
parent
cfdfcf2018
commit
2340ed629e
@@ -107,7 +107,7 @@ def state_dicts_almost_equal(sd1, sd2):
|
||||
|
||||
models_are_equal = True
|
||||
for ten1, ten2 in zip(sd1.values(), sd2.values()):
|
||||
if (ten1 - ten2).abs().sum() > 1e-3:
|
||||
if (ten1 - ten2).abs().max() > 1e-3:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
@@ -1432,23 +1432,21 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-3))
|
||||
|
||||
def test_sdxl_1_0_fuse_unfuse_all(self):
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
|
||||
text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
|
||||
unet_sd = copy.deepcopy(pipe.unet.state_dict())
|
||||
|
||||
pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
|
||||
pipe.load_lora_weights(
|
||||
"davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.fuse_lora()
|
||||
pipe.unload_lora_weights()
|
||||
pipe.unfuse_lora()
|
||||
|
||||
new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
|
||||
new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
|
||||
new_unet_sd = copy.deepcopy(pipe.unet.state_dict())
|
||||
|
||||
assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
|
||||
assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
|
||||
assert state_dicts_almost_equal(unet_sd, new_unet_sd)
|
||||
assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict())
|
||||
assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict())
|
||||
assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict())
|
||||
|
||||
def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user