mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Revert "search buffers for dtype" (#23308)
Revert "search buffers for dtype (#23159)"
This reverts commit ef42c2c487
.
This commit is contained in:
parent
ba71d9e94c
commit
273f5ba026
@ -207,29 +207,21 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
|
||||
# if no floating dtype was found return whatever the first dtype is
|
||||
return last_dtype
|
||||
|
||||
for t in parameter.buffers():
|
||||
last_dtype = t.dtype
|
||||
if t.is_floating_point():
|
||||
return t.dtype
|
||||
else:
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
if last_dtype is not None:
|
||||
# if no floating dtype was found return whatever the first dtype is
|
||||
return last_dtype
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
last_tuple = None
|
||||
for tuple in gen:
|
||||
last_tuple = tuple
|
||||
if tuple[1].is_floating_point():
|
||||
return tuple[1].dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
last_tuple = None
|
||||
for tuple in gen:
|
||||
last_tuple = tuple
|
||||
if tuple[1].is_floating_point():
|
||||
return tuple[1].dtype
|
||||
|
||||
# fallback to the last dtype
|
||||
return last_tuple[1].dtype
|
||||
# fallback to the last dtype
|
||||
return last_tuple[1].dtype
|
||||
|
||||
|
||||
def get_state_dict_float_dtype(state_dict):
|
||||
|
Loading…
Reference in New Issue
Block a user