mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
try:
|
||||
return next(parameter.parameters()).dtype
|
||||
except StopIteration:
|
||||
try:
|
||||
return next(parameter.buffers()).dtype
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
"""
|
||||
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||
"""
|
||||
last_dtype = None
|
||||
for param in parameter.parameters():
|
||||
last_dtype = param.dtype
|
||||
if param.is_floating_point():
|
||||
return param.dtype
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
for buffer in parameter.buffers():
|
||||
last_dtype = buffer.dtype
|
||||
if buffer.is_floating_point():
|
||||
return buffer.dtype
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
if last_dtype is not None:
|
||||
# if no floating dtype was found return whatever the first dtype is
|
||||
return last_dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
last_tuple = None
|
||||
for tuple in gen:
|
||||
last_tuple = tuple
|
||||
if tuple[1].is_floating_point():
|
||||
return tuple[1].dtype
|
||||
|
||||
if last_tuple is not None:
|
||||
# fallback to the last dtype
|
||||
return last_tuple[1].dtype
|
||||
|
||||
|
||||
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
Reference in New Issue
Block a user