mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Fixing bug with param count without embeddings (#12461)
* fixing bug with param count without embeddings

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent d5b8fe3b90
commit 7f0027db30
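Before the diff itself, a hedged usage sketch of the method this commit fixes: `num_parameters(only_trainable=..., exclude_embeddings=...)` on `ModuleUtilsMixin` is the API touched here, while the `gpt2` checkpoint is only an illustrative choice.

```python
# Hedged usage sketch (not part of the commit); "gpt2" is just an example checkpoint.
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2")

total = model.num_parameters()
non_embedding = model.num_parameters(exclude_embeddings=True)
trainable = model.num_parameters(only_trainable=True)

print(f"total={total}, non-embedding={non_embedding}, trainable={trainable}")
```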
@@ -352,11 +352,16 @@ class ModuleUtilsMixin:
             :obj:`int`: The number of parameters.
         """
 
-        def parameter_filter(x):
-            return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
-
-        params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
-        return sum(p.numel() for p in params)
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
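Reading the removed code, the old filter called `isinstance(x, nn.Embedding)` on the objects yielded by `self.parameters()`, which are `nn.Parameter` instances and never `nn.Embedding`, and it was only applied when `only_trainable=True`; so `exclude_embeddings=True` effectively never excluded anything. The replacement matches embedding weights by name instead. Below is a minimal, self-contained sketch of that name-based logic on a toy module; `TinyModel`, the free-standing `count_parameters` helper, and the layer sizes are illustrative assumptions, not code from the repository.

```python
# Minimal sketch of the name-based exclusion added in this commit.
# TinyModel and count_parameters are illustrative, not repository code.
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)  # 100 * 16 = 1,600 params
        self.proj = nn.Linear(16, 4)        # 16 * 4 + 4 = 68 params


def count_parameters(model: nn.Module, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
    if exclude_embeddings:
        # Collect the fully qualified weight names of every nn.Embedding submodule...
        embedding_param_names = [
            f"{name}.weight" for name, module in model.named_modules() if isinstance(module, nn.Embedding)
        ]
        # ...and skip those names when summing.
        params = [p for name, p in model.named_parameters() if name not in embedding_param_names]
    else:
        params = list(model.parameters())
    return sum(p.numel() for p in params if p.requires_grad or not only_trainable)


model = TinyModel()
print(count_parameters(model))                           # 1668
print(count_parameters(model, exclude_embeddings=True))  # 68
```

Note that matching on `f"{name}.weight"` only skips parameters owned by `nn.Embedding` submodules; tied or separate output projections implemented as `nn.Linear` layers are still counted.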