mirror of https://github.com/huggingface/transformers.git

num_parameters helper

commit 84c0aa1868
parent 331065e62d
modeling_tf_utils.py

@@ -20,6 +20,7 @@ import logging
 import os
 
 import h5py
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.saving import hdf5_format
 
@@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
 
 
-class TFPreTrainedModel(tf.keras.Model):
+class TFModelUtils:
+    """
+    A few utilities for `tf.keras.Model`s, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the model.
+        """
+        if only_trainable:
+            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
+        else:
+            return self.count_params()
+
+
+class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
     r""" Base class for all TF models.
 
     :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
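The trainable-only count is computed by hand with np.prod over each variable's shape because Keras's count_params() covers all weights. A minimal sketch of the same logic on a throwaway Sequential model (layer sizes here are illustrative, not from the commit):

import numpy as np
import tensorflow as tf

# Toy model: one trainable layer plus one frozen layer.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(4, input_shape=(3,)),  # 3*4 + 4 = 16 params
        tf.keras.layers.Dense(2, trainable=False),   # 4*2 + 2 = 10 params
    ]
)

# Same logic as TFModelUtils.num_parameters:
trainable = int(sum(np.prod(w.shape.as_list()) for w in model.trainable_variables))
total = model.count_params()
print(trainable, total)  # 16 26 -- the frozen layer is excluded from the trainable count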
modeling_utils.py

@@ -53,7 +53,20 @@ except ImportError:
         return input
 
 
-class PreTrainedModel(nn.Module):
+class ModuleUtils:
+    """
+    A few utilities for torch.nn.Modules, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get number of (optionally, trainable) parameters in the module.
+        """
+        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
+        return sum(p.numel() for p in params)
+
+
+class PreTrainedModel(nn.Module, ModuleUtils):
     r""" Base class for all models.
 
     :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
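On the PyTorch side, parameters() plus numel() is enough. A standalone sketch of the same counting logic on a toy module (sizes invented for illustration):

import torch.nn as nn

# Toy module: a trainable layer followed by a frozen one.
module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2))
for p in module[1].parameters():
    p.requires_grad = False

# Same logic as ModuleUtils.num_parameters:
def num_parameters(module, only_trainable=False):
    params = filter(lambda x: x.requires_grad, module.parameters()) if only_trainable else module.parameters()
    return sum(p.numel() for p in params)

print(num_parameters(module))                       # 26 = (3*4 + 4) + (4*2 + 2)
print(num_parameters(module, only_trainable=True))  # 16 = 3*4 + 4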
test_modeling_auto.py

@@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, BertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
test_modeling_tf_auto.py

@@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, TFBertForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
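Both assertions in each test expect the same count because every weight in the tiny test checkpoint appears to be trainable; the two code paths only diverge once some parameters are frozen. From user code the helper reads the same on either backend (checkpoint name illustrative):

from transformers import AutoModelWithLMHead

model = AutoModelWithLMHead.from_pretrained("bert-base-uncased")
print(model.num_parameters())                     # total parameter count
print(model.num_parameters(only_trainable=True))  # same, unless some weights were frozen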