1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[gguf] Refactor __torch_function__ to avoid unnecessary computation (#11551)

* [gguf] Refactor __torch_function__ to avoid unnecessary computation

This helps with torch.compile compilation latency. Avoiding unnecessary
computation should also lead to a slightly improved eager latency.

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Animesh Jain
2025-05-15 02:08:18 -07:00
committed by GitHub
parent 4267d8f4eb
commit 3a6caba8e4

View File

@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
def as_tensor(self):
    """Return a plain ``torch.Tensor`` view of this parameter.

    Strips the GGUFParameter subclass (and therefore the attached
    ``quant_type`` metadata) while keeping the same storage and
    ``requires_grad`` flag.
    """
    plain_view = torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
    return plain_view
@staticmethod
def _extract_quant_type(args):
    """Return the ``quant_type`` of the first GGUFParameter found in *args*.

    When converting from original-format checkpoints we often use splits,
    cats, etc. on tensors; callers use the value returned here to re-wrap
    the results of those operations as GGUFParameter so that ``quant_type``
    information is preserved.

    Args:
        args: Positional arguments forwarded from ``__torch_function__``;
            GGUFParameters may appear either directly or as elements of a
            list (e.g. the sequence passed to ``torch.cat``).

    Returns:
        The ``quant_type`` of the first GGUFParameter encountered, or
        ``None`` when no GGUFParameter is present.
    """
    for arg in args:
        # Guard against empty lists before peeking at arg[0] — the
        # original code raised IndexError on e.g. torch.cat([]).
        if isinstance(arg, list) and arg and isinstance(arg[0], GGUFParameter):
            return arg[0].quant_type
        if isinstance(arg, GGUFParameter):
            return arg.quant_type
    return None
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
result = super().__torch_function__(func, types, args, kwargs)
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
quant_type = None
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
quant_type = cls._extract_quant_type(args)
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
elif type(result) in (list, tuple):
# Preserve the original type (tuple or list)
quant_type = cls._extract_quant_type(args)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else: