diff --git a/nodes.py b/nodes.py index 37e7ea4..b25b7c9 100644 --- a/nodes.py +++ b/nodes.py @@ -42,18 +42,22 @@ def offload_transformer(transformer): transformer.teacache_state.clear_all() transformer.magcache_state.clear_all() transformer.easycache_state.clear_all() - #transformer.to(offload_device) - for name, param in transformer.named_parameters(): - module = transformer - subnames = name.split('.') - for subname in subnames[:-1]: - module = getattr(module, subname) - attr_name = subnames[-1] - if param.data.is_floating_point(): - meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False) - setattr(module, attr_name, meta_param) - else: - pass + + if transformer.patched_linear: + for name, param in transformer.named_parameters(): + module = transformer + subnames = name.split('.') + for subname in subnames[:-1]: + module = getattr(module, subname) + attr_name = subnames[-1] + if param.data.is_floating_point(): + meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False) + setattr(module, attr_name, meta_param) + else: + pass + else: + transformer.to(offload_device) + mm.soft_empty_cache() gc.collect()