num_parameters helper

This commit is contained in:
Julien Chaumond 2020-01-10 17:40:02 +00:00
parent 331065e62d
commit 84c0aa1868
4 changed files with 35 additions and 2 deletions

View File

@ -20,6 +20,7 @@ import logging
import os
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format
@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
class TFPreTrainedModel(tf.keras.Model):
class TFModelUtils:
"""
A few utilities for `tf.keras.Model`s, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get number of (optionally, trainable) parameters in the model.
"""
if only_trainable:
return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
else:
return self.count_params()
class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models

View File

@ -53,7 +53,20 @@ except ImportError:
return input
class PreTrainedModel(nn.Module):
class ModuleUtils:
"""
A few utilities for torch.nn.Modules, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get number of (optionally, trainable) parameters in the module.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
class PreTrainedModel(nn.Module, ModuleUtils):
r""" Base class for all models.
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models

View File

@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)

View File

@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)