mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix(training_utils): wrap device in list for DiffusionPipeline (#12178)
- Modify offload_models function to handle DiffusionPipeline correctly - Ensure compatibility with both single and multiple module inputs Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -339,7 +339,8 @@ def offload_models(
|
||||
original_devices = [next(m.parameters()).device for m in modules]
|
||||
else:
|
||||
assert len(modules) == 1
|
||||
original_devices = modules[0].device
|
||||
# For DiffusionPipeline, wrap the device in a list to make it iterable
|
||||
original_devices = [modules[0].device]
|
||||
# move to target device
|
||||
for m in modules:
|
||||
m.to(device)
|
||||
|
||||
Reference in New Issue
Block a user