mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
num_parameters helper
This commit is contained in:
parent 331065e62d
commit 84c0aa1868
@@ -20,6 +20,7 @@ import logging
 import os
 
 import h5py
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.saving import hdf5_format
 
@@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
 
 
-class TFPreTrainedModel(tf.keras.Model):
+class TFModelUtils:
+    """
+    A few utilities for `tf.keras.Model`s, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the model.
+        """
+        if only_trainable:
+            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
+        else:
+            return self.count_params()
+
+
+class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
     r""" Base class for all TF models.
 
     :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
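For context, a minimal usage sketch (not part of this commit; `TinyModel` is a hypothetical toy model): any `tf.keras.Model` that also inherits from `TFModelUtils` picks up `num_parameters`, which defers to Keras's `count_params()` for the full count and sums `np.prod(w.shape)` over `trainable_variables` when `only_trainable=True`.

import numpy as np
import tensorflow as tf


class TFModelUtils:
    # Copied from the hunk above so the sketch is self-contained.
    def num_parameters(self, only_trainable: bool = False) -> int:
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


class TinyModel(tf.keras.Model, TFModelUtils):
    # Hypothetical toy model, not from the transformers library.
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)


model = TinyModel()
model(tf.zeros((1, 3)))  # build the layer: kernel (3, 4) + bias (4,) = 16 parameters
print(model.num_parameters())                     # 16
print(model.num_parameters(only_trainable=True))  # 16, nothing is frozen here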
@@ -53,7 +53,20 @@ except ImportError:
             return input
 
 
-class PreTrainedModel(nn.Module):
+class ModuleUtils:
+    """
+    A few utilities for torch.nn.Modules, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the module.
+        """
+        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
+        return sum(p.numel() for p in params)
+
+
+class PreTrainedModel(nn.Module, ModuleUtils):
     r""" Base class for all models.
 
     :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
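And a matching sketch on the PyTorch side (again not part of this commit; `TinyModule` is hypothetical): mixing `ModuleUtils` into an `nn.Module` gives the same `num_parameters`, with `only_trainable=True` skipping any parameter whose `requires_grad` is False.

import torch.nn as nn


class ModuleUtils:
    # Copied from the hunk above so the sketch is self-contained.
    def num_parameters(self, only_trainable: bool = False) -> int:
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)


class TinyModule(nn.Module, ModuleUtils):
    # Hypothetical toy module, not from the transformers library.
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)  # 10 * 4 = 40 parameters
        self.proj = nn.Linear(4, 2)       # 4 * 2 weights + 2 biases = 10 parameters


module = TinyModule()
module.embed.weight.requires_grad = False  # freeze the embedding table

print(module.num_parameters())                     # 50
print(module.num_parameters(only_trainable=True))  # 10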
@@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, BertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
@@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, TFBertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)