mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add a main_input_name attribute to all models (#14803)
* Add a main_input_name attribute to all models * Fix tests * Wtf Vs Code? * Update src/transformers/models/imagegpt/modeling_imagegpt.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Style * Fix copies Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
0940e9b242
commit
33f36c869f
@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
|
||||
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
main_input_name = "input_ids"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
|
||||
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
main_input_name = "input_ids"
|
||||
|
||||
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
_keys_to_ignore_on_load_missing = None
|
||||
|
@ -17,7 +17,6 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
@ -376,11 +375,10 @@ class ModuleUtilsMixin:
|
||||
Returns:
|
||||
:obj:`int`: The total number of tokens.
|
||||
"""
|
||||
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
|
||||
if token_inputs:
|
||||
return sum([token_input.numel() for token_input in token_inputs])
|
||||
if self.main_input_name in input_dict:
|
||||
return input_dict[self.main_input_name].numel()
|
||||
else:
|
||||
warnings.warn(
|
||||
logger.warn(
|
||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
||||
)
|
||||
return 0
|
||||
@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
|
||||
- **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
|
||||
NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
main_input_name = "input_ids"
|
||||
|
||||
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
_keys_to_ignore_on_load_missing = None
|
||||
|
@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = BeitConfig
|
||||
base_model_prefix = "beit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
config_class = BeitConfig
|
||||
base_model_prefix = "beit"
|
||||
main_input_name = "pixel_values"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
||||
|
@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module):
|
||||
|
||||
class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__(config)
|
||||
|
@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
|
@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = DeiTConfig
|
||||
base_model_prefix = "deit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module):
|
||||
class DetrPreTrainedModel(PreTrainedModel):
|
||||
config_class = DetrConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = HubertConfig
|
||||
base_model_prefix = "hubert"
|
||||
main_input_name = "input_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
|
@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = HubertConfig
|
||||
base_model_prefix = "hubert"
|
||||
main_input_name = "input_values"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
|
@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
|
||||
config_class = ImageGPTConfig
|
||||
load_tf_weights = load_tf_weights_in_imagegpt
|
||||
base_model_prefix = "transformer"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
|
@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = PerceiverConfig
|
||||
base_model_prefix = "perceiver"
|
||||
main_input_name = "inputs"
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = SegformerConfig
|
||||
base_model_prefix = "segformer"
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = SEWConfig
|
||||
base_model_prefix = "sew"
|
||||
main_input_name = "input_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
|
@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = SEWDConfig
|
||||
base_model_prefix = "sew-d"
|
||||
main_input_name = "input_values"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
"""
|
||||
config_class = SpeechEncoderDecoderConfig
|
||||
base_model_prefix = "speech_encoder_decoder"
|
||||
main_input_name = "input_values"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
class Speech2TextPreTrainedModel(PreTrainedModel):
|
||||
config_class = Speech2TextConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_features"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = UniSpeechConfig
|
||||
base_model_prefix = "unispeech"
|
||||
main_input_name = "input_values"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = UniSpeechSatConfig
|
||||
base_model_prefix = "unispeech_sat"
|
||||
main_input_name = "input_values"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
config_class = VisionEncoderDecoderConfig
|
||||
base_model_prefix = "vision_encoder_decoder"
|
||||
main_input_name = "pixel_values"
|
||||
module_class = FlaxVisionEncoderDecoderModule
|
||||
|
||||
def __init__(
|
||||
|
@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
"""
|
||||
config_class = VisionEncoderDecoderConfig
|
||||
base_model_prefix = "vision_encoder_decoder"
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
||||
|
@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
|
@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
config_class = Wav2Vec2Config
|
||||
base_model_prefix: str = "wav2vec2"
|
||||
main_input_name = "input_values"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
|
@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = Wav2Vec2Config
|
||||
base_model_prefix = "wav2vec2"
|
||||
main_input_name = "input_values"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
|
@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = Wav2Vec2Config
|
||||
base_model_prefix = "wav2vec2"
|
||||
main_input_name = "input_values"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = WavLMConfig
|
||||
base_model_prefix = "wavlm"
|
||||
main_input_name = "input_values"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
@ -1315,6 +1315,13 @@ class ModelTesterMixin:
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_model_main_input_name(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_signature = inspect.signature(getattr(model_class, "forward"))
|
||||
# The main input is the name of the argument after `self`
|
||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
||||
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
return
|
||||
|
@ -778,6 +778,13 @@ class FlaxModelTesterMixin:
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
def test_model_main_input_name(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_signature = inspect.signature(getattr(model_class, "__call__"))
|
||||
# The main input is the name of the argument after `self`
|
||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
@ -1183,6 +1183,13 @@ class TFModelTesterMixin:
|
||||
else:
|
||||
new_model_without_prefix(input_ids)
|
||||
|
||||
def test_model_main_input_name(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_signature = inspect.signature(getattr(model_class, "call"))
|
||||
# The main input is the name of the argument after `self`
|
||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = []
|
||||
|
Loading…
Reference in New Issue
Block a user