1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix(training_utils): wrap device in list for DiffusionPipeline (#12178)

- Modify offload_models function to handle DiffusionPipeline correctly
- Ensure compatibility with both single and multiple module inputs

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
MQY
2025-08-18 16:26:17 +08:00
committed by GitHub
parent e824660436
commit 9918d13eba

View File

@@ -339,7 +339,8 @@ def offload_models(
original_devices = [next(m.parameters()).device for m in modules]
else:
assert len(modules) == 1
original_devices = modules[0].device
# For DiffusionPipeline, wrap the device in a list to make it iterable
original_devices = [modules[0].device]
# move to target device
for m in modules:
m.to(device)