Add a main_input_name attribute to all models (#14803)

* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger 2021-12-20 11:19:08 -05:00 committed by GitHub
parent 0940e9b242
commit 33f36c869f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 61 additions and 5 deletions

View File

@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
def __init__(
self,

View File

@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
_keys_to_ignore_on_load_missing = None

View File

@ -17,7 +17,6 @@
import inspect
import os
import re
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
@ -376,11 +375,10 @@ class ModuleUtilsMixin:
Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
if self.main_input_name in input_dict:
return input_dict[self.main_input_name].numel()
else:
warnings.warn(
logger.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0
@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
_keys_to_ignore_on_load_missing = None

View File

@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
config_class = BeitConfig
base_model_prefix = "beit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):

View File

@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
config_class = BeitConfig
base_model_prefix = "beit"
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):

View File

@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module):
class CLIPVisionModel(CLIPPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)

View File

@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(

View File

@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
config_class = DeiTConfig
base_model_prefix = "deit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):

View File

@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module):
class DetrPreTrainedModel(PreTrainedModel):
config_class = DetrConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
def _init_weights(self, module):
std = self.config.init_std

View File

@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
config_class = HubertConfig
base_model_prefix = "hubert"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

View File

@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
config_class = HubertConfig
base_model_prefix = "hubert"
main_input_name = "input_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:

View File

@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
config_class = ImageGPTConfig
load_tf_weights = load_tf_weights_in_imagegpt
base_model_prefix = "transformer"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs):

View File

@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
config_class = PerceiverConfig
base_model_prefix = "perceiver"
main_input_name = "inputs"
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
config_class = SegformerConfig
base_model_prefix = "segformer"
main_input_name = "pixel_values"
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
config_class = SEWConfig
base_model_prefix = "sew"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

View File

@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
config_class = SEWDConfig
base_model_prefix = "sew-d"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True

View File

@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
"""
config_class = SpeechEncoderDecoderConfig
base_model_prefix = "speech_encoder_decoder"
main_input_name = "input_values"
def __init__(
self,

View File

@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module):
class Speech2TextPreTrainedModel(PreTrainedModel):
config_class = Speech2TextConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
def _init_weights(self, module):

View File

@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
config_class = UniSpeechConfig
base_model_prefix = "unispeech"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True

View File

@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
config_class = UniSpeechSatConfig
base_model_prefix = "unispeech_sat"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True

View File

@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
module_class = FlaxVisionEncoderDecoderModule
def __init__(

View File

@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
def __init__(
self,

View File

@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
config_class = ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):

View File

@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel):
config_class = ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:

View File

@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
config_class = ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):

View File

@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix: str = "wav2vec2"
main_input_name = "input_values"
module_class: nn.Module = None
def __init__(

View File

@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
main_input_name = "input_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:

View File

@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True

View File

@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
config_class = WavLMConfig
base_model_prefix = "wavlm"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True

View File

@ -1315,6 +1315,13 @@ class ModelTesterMixin:
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_correct_missing_keys(self):
if not self.test_missing_keys:
return

View File

@ -778,6 +778,13 @@ class FlaxModelTesterMixin:
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "__call__"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_headmasking(self):
if not self.test_head_masking:
return

View File

@ -1183,6 +1183,13 @@ class TFModelTesterMixin:
else:
new_model_without_prefix(input_ids)
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "call"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []