From 2a3ca69c83144192d1888853cec2df1a22477342 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 26 Aug 2025 22:34:12 +0300
Subject: [PATCH] Fix offloading when using merged loras

---
 nodes.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/nodes.py b/nodes.py
index 37e7ea4..b25b7c9 100644
--- a/nodes.py
+++ b/nodes.py
@@ -42,18 +42,22 @@ def offload_transformer(transformer):
     transformer.teacache_state.clear_all()
     transformer.magcache_state.clear_all()
     transformer.easycache_state.clear_all()
-    #transformer.to(offload_device)
-    for name, param in transformer.named_parameters():
-        module = transformer
-        subnames = name.split('.')
-        for subname in subnames[:-1]:
-            module = getattr(module, subname)
-        attr_name = subnames[-1]
-        if param.data.is_floating_point():
-            meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False)
-            setattr(module, attr_name, meta_param)
-        else:
-            pass
+
+    if transformer.patched_linear:
+        for name, param in transformer.named_parameters():
+            module = transformer
+            subnames = name.split('.')
+            for subname in subnames[:-1]:
+                module = getattr(module, subname)
+            attr_name = subnames[-1]
+            if param.data.is_floating_point():
+                meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False)
+                setattr(module, attr_name, meta_param)
+            else:
+                pass
+    else:
+        transformer.to(offload_device)
+
     mm.soft_empty_cache()
     gc.collect()
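
Context for reviewers: the patched_linear branch above frees weights by swapping
each floating-point parameter for an empty tensor on the 'meta' device, rather
than copying the model to the offload device. Below is a minimal standalone
sketch of that technique; the helper name free_float_params, the toy model, and
the re-materialization note are illustrative assumptions, not code from nodes.py.

    import gc
    import torch

    def free_float_params(model: torch.nn.Module) -> None:
        # Replace every floating-point parameter with an empty tensor on
        # the 'meta' device (shape/dtype metadata only, no storage).
        # Unlike model.to(offload_device), this releases the memory
        # outright instead of copying it to another device, which suits
        # the merged-LoRA case: the merged weights are not worth keeping,
        # since the pristine ones must be reloaded from the checkpoint.
        for name, param in model.named_parameters():
            # Walk to the submodule that owns the parameter, e.g.
            # "blocks.0.attn.q.weight" -> model.blocks[0].attn.q
            module = model
            *parents, attr_name = name.split('.')
            for subname in parents:
                module = getattr(module, subname)
            if param.data.is_floating_point():
                setattr(module, attr_name, torch.nn.Parameter(
                    torch.empty_like(param.data, device='meta'),
                    requires_grad=False))
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Illustrative usage (toy model, an assumption): afterwards the
    # weights live on 'meta' and must be re-materialized, e.g. via
    # load_state_dict(..., assign=True), before the next forward pass.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    free_float_params(model)
    print(next(model.parameters()).device)  # meta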