Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Factor out methods (#10215)
commit 4b91965731 (parent e94d63f6cb)
@@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices(
     return heads, index
 
 
+def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).dtype
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].dtype
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
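The two module-level helpers added above reuse the logic that previously lived inside ModuleUtilsMixin, so they can be called on any nn.Module. A minimal usage sketch, assuming the diff applies to src/transformers/modeling_utils.py and using a throwaway module defined only for illustration:

    import torch
    from torch import nn
    from transformers.modeling_utils import get_parameter_device, get_parameter_dtype

    # Hypothetical toy module, not part of the diff; any nn.Module works.
    toy = nn.Linear(4, 4)

    print(get_parameter_device(toy))  # e.g. device(type='cpu')
    print(get_parameter_dtype(toy))   # e.g. torch.float32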
@@ -145,36 +175,14 @@ class ModuleUtilsMixin:
         :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
         """
-        try:
-            return next(self.parameters()).device
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].device
+        return get_parameter_device(self)
 
     @property
     def dtype(self) -> dtype:
         """
         :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-        try:
-            return next(self.parameters()).dtype
-        except StopIteration:
-            # For nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = self._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].dtype
+        return get_parameter_dtype(self)
 
     def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         """
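With this hunk, ModuleUtilsMixin.device and ModuleUtilsMixin.dtype become thin wrappers around the new helpers, so for any model built on the mixin the property and the free function should agree. A small sketch of that equivalence (the checkpoint name is only an example, and the import path assumes modeling_utils.py):

    from transformers import AutoModel
    from transformers.modeling_utils import get_parameter_device, get_parameter_dtype

    model = AutoModel.from_pretrained("bert-base-uncased")  # example checkpoint

    # The properties now delegate to the factored-out helpers.
    assert model.device == get_parameter_device(model)
    assert model.dtype == get_parameter_dtype(model)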
@@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module):
         x = self.dense(hidden_states).squeeze(-1)
 
         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
@@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module):
         x = self.dense_1(x).squeeze(-1)
 
         if p_mask is not None:
-            if next(self.parameters()).dtype == torch.float16:
+            if get_parameter_dtype(self) == torch.float16:
                 x = x * (1 - p_mask) - 65500 * p_mask
             else:
                 x = x * (1 - p_mask) - 1e30 * p_mask
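In both pooler hunks the dtype check guards the masking constant: 1e30 is not representable in float16 (whose largest finite value is roughly 65504), so 65500 is used instead. A standalone sketch of the same masking arithmetic on dummy tensors, independent of the classes above:

    import torch

    x = torch.zeros(2, 3, dtype=torch.float16)
    p_mask = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float16)

    # Masked positions receive a large negative value that still fits in float16.
    masked = x * (1 - p_mask) - 65500 * p_mask
    print(masked)  # masked positions are -65500, the rest are unchanged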