Add a main_input_name attribute to all models (#14803)

* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sylvain Gugger 2021-12-20 11:19:08 -05:00 committed by GitHub
parent 0940e9b242
commit 33f36c869f
32 changed files with 61 additions and 5 deletions


@@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
    """

    config_class = None
    base_model_prefix = ""
+    main_input_name = "input_ids"

    def __init__(
        self,
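
The new attribute is a plain class attribute, so it can be read without instantiating a model. A minimal sketch, assuming an installed transformers that already contains this commit; the concrete model classes below are illustrative (BERT simply inherits the default, the ViT and Wav2Vec2 overrides are added later in this diff):

```python
# Read the new class attribute directly; no weights are loaded.
from transformers import BertModel, ViTModel, Wav2Vec2Model

print(BertModel.main_input_name)      # "input_ids"    -- inherited default for NLP models
print(ViTModel.main_input_name)       # "pixel_values" -- vision override added in this commit
print(Wav2Vec2Model.main_input_name)  # "input_values" -- speech override added in this commit
```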


@@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
    """

    config_class = None
    base_model_prefix = ""
+    main_input_name = "input_ids"
    # a list of re pattern of tensor names to ignore from the model when loading the model weights
    # (and avoid unnecessary warnings).
    _keys_to_ignore_on_load_missing = None


@ -17,7 +17,6 @@
import inspect import inspect
import os import os
import re import re
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
@ -376,11 +375,10 @@ class ModuleUtilsMixin:
Returns: Returns:
:obj:`int`: The total number of tokens. :obj:`int`: The total number of tokens.
""" """
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key] if self.main_input_name in input_dict:
if token_inputs: return input_dict[self.main_input_name].numel()
return sum([token_input.numel() for token_input in token_inputs])
else: else:
warnings.warn( logger.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed" "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
) )
return 0 return 0
@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model. derived classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization. - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
""" """
config_class = None config_class = None
base_model_prefix = "" base_model_prefix = ""
main_input_name = "input_ids"
# a list of re pattern of tensor names to ignore from the model when loading the model weights # a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings). # (and avoid unnecessary warnings).
_keys_to_ignore_on_load_missing = None _keys_to_ignore_on_load_missing = None
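
The old `estimate_tokens` summed every tensor whose key contained `"input"`; the new version counts only the tensor named by `main_input_name`. A minimal standalone sketch of that behaviour (the free function and the dummy batches are illustrative, not part of this commit):

```python
import torch

def estimate_tokens(input_dict, main_input_name="input_ids"):
    # Count the elements of the main input, or 0 if it is missing.
    if main_input_name in input_dict:
        return input_dict[main_input_name].numel()
    print("Could not estimate the number of tokens of the input")
    return 0

print(estimate_tokens({"input_ids": torch.ones(2, 128, dtype=torch.long)}))       # 256
print(estimate_tokens({"pixel_values": torch.ones(1, 3, 4, 4)}, "pixel_values"))  # 48
print(estimate_tokens({"attention_mask": torch.ones(2, 128)}))                    # 0, with a warning
```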


@@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
    config_class = BeitConfig
    base_model_prefix = "beit"
+    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):


@@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
    config_class = BeitConfig
    base_model_prefix = "beit"
+    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):


@@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module):
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)


@@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
    config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(


@@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
    config_class = DeiTConfig
    base_model_prefix = "deit"
+    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):


@@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module):
class DetrPreTrainedModel(PreTrainedModel):
    config_class = DetrConfig
    base_model_prefix = "model"
+    main_input_name = "pixel_values"

    def _init_weights(self, module):
        std = self.config.init_std


@@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
    config_class = HubertConfig
    base_model_prefix = "hubert"
+    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]


@@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
    config_class = HubertConfig
    base_model_prefix = "hubert"
+    main_input_name = "input_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:


@@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
    config_class = ImageGPTConfig
    load_tf_weights = load_tf_weights_in_imagegpt
    base_model_prefix = "transformer"
+    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs):


@@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
    config_class = PerceiverConfig
    base_model_prefix = "perceiver"
+    main_input_name = "inputs"

    def _init_weights(self, module):
        """Initialize the weights"""


@@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
    config_class = SegformerConfig
    base_model_prefix = "segformer"
+    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""


@@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
    config_class = SEWConfig
    base_model_prefix = "sew"
+    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]


@@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
    config_class = SEWDConfig
    base_model_prefix = "sew-d"
+    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True


@@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
    """
    config_class = SpeechEncoderDecoderConfig
    base_model_prefix = "speech_encoder_decoder"
+    main_input_name = "input_values"

    def __init__(
        self,


@@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module):
class Speech2TextPreTrainedModel(PreTrainedModel):
    config_class = Speech2TextConfig
    base_model_prefix = "model"
+    main_input_name = "input_features"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):


@@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
    config_class = UniSpeechConfig
    base_model_prefix = "unispeech"
+    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True


@@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
    config_class = UniSpeechSatConfig
    base_model_prefix = "unispeech_sat"
+    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True


@@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
    """
    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"
    module_class = FlaxVisionEncoderDecoderModule

    def __init__(


@@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
    """
    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"

    def __init__(
        self,


@@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
    config_class = ViTConfig
    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):


@@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel):
    config_class = ViTConfig
    base_model_prefix = "vit"
+    main_input_name = "pixel_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:


@@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
    config_class = ViTConfig
    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):


@@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
+    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(


@@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
    config_class = Wav2Vec2Config
    base_model_prefix = "wav2vec2"
+    main_input_name = "input_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:


@@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
    config_class = Wav2Vec2Config
    base_model_prefix = "wav2vec2"
+    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True


@@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
    config_class = WavLMConfig
    base_model_prefix = "wavlm"
+    main_input_name = "input_values"
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True
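
With every architecture now exposing its principal input under the same attribute, model-agnostic code can route a tensor to the right keyword argument without hard-coding per-modality names. A hedged sketch (the helper is an assumption for illustration, not an API added by this commit; the model is randomly initialised so nothing is downloaded):

```python
import torch
from transformers import ViTConfig, ViTModel

def forward_main_input(model, tensor):
    # Feed `tensor` to whatever argument the model declares as its main input.
    return model(**{model.main_input_name: tensor})

vit = ViTModel(ViTConfig())  # random weights, no download
outputs = forward_main_input(vit, torch.randn(1, 3, 224, 224))
print(vit.main_input_name, outputs.last_hidden_state.shape)  # pixel_values torch.Size([1, 197, 768])
```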


@@ -1315,6 +1315,13 @@ class ModelTesterMixin:
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "forward"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
    def test_correct_missing_keys(self):
        if not self.test_missing_keys:
            return


@@ -778,6 +778,13 @@ class FlaxModelTesterMixin:
            for name, type_ in types.items():
                self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")

+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "__call__"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
    def test_headmasking(self):
        if not self.test_head_masking:
            return


@@ -1183,6 +1183,13 @@ class TFModelTesterMixin:
            else:
                new_model_without_prefix(input_ids)

+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "call"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
    def _generate_random_bad_tokens(self, num_bad_tokens, model):
        # special tokens cannot be bad tokens
        special_tokens = []
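
The three new tests all rely on the same convention: the main input is the first parameter after `self` in the model's `forward`/`call`/`__call__` signature. A standalone sketch of that check, assuming an installed transformers containing this commit; the model classes are illustrative:

```python
import inspect
from transformers import BertModel, ViTModel

def observed_main_input_name(model_class, entry_point="forward"):
    # First parameter after `self` in the (unbound) forward signature.
    signature = inspect.signature(getattr(model_class, entry_point))
    return list(signature.parameters.keys())[1]

for cls in (BertModel, ViTModel):
    assert cls.main_input_name == observed_main_input_name(cls)
    print(cls.__name__, "->", cls.main_input_name)
```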