Fix dtype getter (#17668)

* Fix dtype getters

* Proper fix for dtype getter

* Style and commant

* Always use last for consistency

* Quality
This commit is contained in:
Sylvain Gugger 2022-06-13 09:34:45 -04:00 committed by GitHub
parent 73083581a4
commit a1344dbfb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -139,7 +139,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
try:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
@ -152,31 +152,33 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the first dtype it found.
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
try:
for t in parameter.parameters():
if t.is_floating_point():
return t.dtype
last_dtype = None
for t in parameter.parameters():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
else:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
return last_dtype
else:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
# fallback to any dtype the model has even if not floating
else:
first_tuple = next(gen)
return first_tuple[1].dtype
# fallback to the last dtype
return last_tuple[1].dtype
def get_state_dict_float_dtype(state_dict):