Factor out methods (#10215)

This commit is contained in:
Lysandre Debut 2021-02-17 15:53:43 +01:00 committed by GitHub
parent e94d63f6cb
commit 4b91965731
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices(
return heads, index
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].device
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
class ModuleUtilsMixin:
"""
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
@ -145,36 +175,14 @@ class ModuleUtilsMixin:
:obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
try:
return next(self.parameters()).device
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].device
return get_parameter_device(self)
@property
def dtype(self) -> dtype:
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
try:
return next(self.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
return get_parameter_dtype(self)
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""
@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module):
x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
if get_parameter_dtype(self) == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module):
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
if get_parameter_dtype(self) == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask