mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[gguf] Refactor __torch_function__ to avoid unnecessary computation (#11551)
* [gguf] Refactor __torch_function__ to avoid unnecessary computation This helps with torch.compile compilation lantency. Avoiding unnecessary computation should also lead to a slightly improved eager latency. * Apply style fixes --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
|
||||
def as_tensor(self):
|
||||
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
|
||||
|
||||
@staticmethod
|
||||
def _extract_quant_type(args):
|
||||
# When converting from original format checkpoints we often use splits, cats etc on tensors
|
||||
# this method ensures that the returned tensor type from those operations remains GGUFParameter
|
||||
# so that we preserve quant_type information
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
||||
return arg[0].quant_type
|
||||
if isinstance(arg, GGUFParameter):
|
||||
return arg.quant_type
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
|
||||
|
||||
result = super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
# When converting from original format checkpoints we often use splits, cats etc on tensors
|
||||
# this method ensures that the returned tensor type from those operations remains GGUFParameter
|
||||
# so that we preserve quant_type information
|
||||
quant_type = None
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
||||
quant_type = arg[0].quant_type
|
||||
break
|
||||
if isinstance(arg, GGUFParameter):
|
||||
quant_type = arg.quant_type
|
||||
break
|
||||
if isinstance(result, torch.Tensor):
|
||||
quant_type = cls._extract_quant_type(args)
|
||||
return cls(result, quant_type=quant_type)
|
||||
# Handle tuples and lists
|
||||
elif isinstance(result, (tuple, list)):
|
||||
elif type(result) in (list, tuple):
|
||||
# Preserve the original type (tuple or list)
|
||||
quant_type = cls._extract_quant_type(args)
|
||||
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
|
||||
return type(result)(wrapped)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user