Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
Sylvain Gugger 2021-04-08 18:41:36 -04:00 committed by GitHub
parent 97ccf67bb3
commit ba8b1f4754
26 changed files with 188 additions and 72 deletions
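
What this enables, as a minimal usage sketch (assuming a transformers version that includes this change; the Funnel tuple is the one registered in this diff, and FunnelBaseModel is its second, non-default member):

    from transformers import AutoModel, FunnelConfig

    config = FunnelConfig()                     # a config whose mapping entry is now a tuple
    config.architectures = ["FunnelBaseModel"]  # select a non-default member of the tuple

    model = AutoModel.from_config(config)
    print(type(model).__name__)                 # "FunnelBaseModel"; without `architectures`
                                                # set, AutoModel falls back to FunnelModel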

View File

@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
# get abs dir
save_directory = os.path.abspath(save_directory)
# save config as well
self.config.architectures = [self.__class__.__name__[4:]]
self.config.save_pretrained(save_directory)
# save model
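
A quick illustration of the slicing above (a note added here, not part of the diff): Flax model classes are named with a `Flax` prefix, so dropping the first four characters records the framework-agnostic architecture name in the config:

    "FlaxBertModel"[4:]  # -> "BertModel", the name stored in config.architectures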

View File

@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
logger.info(f"Saved model created in {saved_model_dir}")
# Save configuration file
self.config.architectures = [self.__class__.__name__[2:]]
self.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained`
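
Same idea as the Flax change, illustrated (not part of the diff): TF class names carry a two-character `TF` prefix, so:

    "TFBertModel"[2:]  # -> "BertModel", the name stored in config.architectures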

View File

@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
_import_structure = {
"auto_factory": ["get_values"],
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
@@ -104,6 +105,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .auto_factory import get_values
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

View File

@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
"""
def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models
name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]
# If no architecture is set in the config, or none of them matches the supported models, the first element of
# the tuple is the default.
return supported_models[0]
class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None
@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)](config, **kwargs)
model_class = _get_model_class(config, cls._model_mapping)
return model_class(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
)
if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
new_class.from_pretrained = classmethod(from_pretrained)
return new_class
def get_values(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)
return result
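
To make the resolution rules concrete, an illustrative sketch with toy stand-in classes (hypothetical names, not from this diff) run against `_get_model_class` and `get_values` as defined above:

    from collections import OrderedDict

    class ToyConfig: ...
    class ToyModel: ...        # first element of the tuple, i.e. the default
    class ToyOtherModel: ...

    mapping = OrderedDict([(ToyConfig, (ToyModel, ToyOtherModel))])

    config = ToyConfig()
    config.architectures = ["ToyOtherModel"]
    assert _get_model_class(config, mapping) is ToyOtherModel  # matched by name
                                                               # (TF/Flax prefixes are tried too)
    config.architectures = []
    assert _get_model_class(config, mapping) is ToyModel       # no match -> first element

    assert get_values(mapping) == [ToyModel, ToyOtherModel]    # tuple values are flattened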

View File

@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
)
def _get_class_name(model_class):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
return f":class:`~transformers.{model_class.__name__}`"
def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config_to_class is None and not use_model_types:
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
if use_model_types:
if config_to_class is None:
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
model_type_to_name = {
model_type: f":class:`~transformers.{config.__name__}`"
for model_type, config in CONFIG_MAPPING.items()
}
else:
model_type_to_name = {
model_type: config_to_class[config].__name__
model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class
}
lines = [
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
}
lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
for config_name in sorted(config_to_name.keys())
]
return "\n".join(lines)

View File

@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
)
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from ..funnel.modeling_funnel import (
FunnelBaseModel,
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
(FunnelConfig, FunnelModel),
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DebertaConfig, DebertaModel),

View File

@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM,
TFFunnelForMultipleChoice,
TFFunnelForPreTraining,
@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
(XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel),
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -234,7 +235,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_torch_available
@@ -46,6 +47,8 @@ if is_torch_available():
BertForSequenceClassification,
BertForTokenClassification,
BertModel,
FunnelBaseModel,
FunnelModel,
GPT2Config,
GPT2LMHeadModel,
RobertaForMaskedLM,
@@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase):
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, FunnelModel)
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = AutoModel.from_config(config)
self.assertIsInstance(model, FunnelBaseModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = AutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, FunnelBaseModel)
def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `isinstance` check would
# otherwise be triggered by the parents and return the wrong configuration type when using auto models
@@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase):
assert not issubclass(
child_config, parent_config
), f"{child_config.__name__} is child of {parent_config.__name__}"
assert not issubclass(
child_model, parent_model
), f"{child_config.__name__} is child of {parent_config.__name__}"
# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -444,7 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -19,6 +19,7 @@ import unittest
from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
from transformers.testing_utils import require_torch, slow, torch_device
@@ -458,7 +459,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -24,6 +24,7 @@ from typing import List, Tuple
from transformers import is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
@@ -79,7 +80,7 @@ class ModelTesterMixin:
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim > 1
@@ -88,9 +89,9 @@ class ModelTesterMixin:
}
if return_labels:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
@@ -98,18 +99,18 @@ class ModelTesterMixin:
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
*MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in [
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*MODEL_FOR_MASKED_LM_MAPPING.values(),
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
@@ -229,7 +230,7 @@ class ModelTesterMixin:
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
if model_class in get_values(MODEL_MAPPING):
continue
model = model_class(config)
model.to(torch_device)
@@ -248,7 +249,7 @@ class ModelTesterMixin:
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
if model_class in get_values(MODEL_MAPPING):
continue
model = model_class(config)
model.to(torch_device)
@@ -312,7 +313,7 @@ class ModelTesterMixin:
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned

View File

@@ -19,6 +19,7 @@ import unittest
from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -352,7 +353,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -292,7 +293,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -29,6 +29,7 @@ if is_flax_available():
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
)
@@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertForQuestionAnswering,
)

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import FunnelTokenizer, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -365,7 +366,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -21,6 +21,7 @@ import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -412,7 +413,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned

View File

@@ -18,6 +18,7 @@ import copy
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -532,11 +533,11 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
# special case for models like BERT that use multi-loss training for PreTraining
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device

View File

@@ -21,6 +21,7 @@ import os
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -290,7 +291,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -272,7 +273,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)

View File

@@ -32,6 +32,7 @@ from transformers import (
is_torch_available,
)
from transformers.file_utils import cached_property
from transformers.models.auto import get_values
from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -425,7 +426,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim > 1
@@ -434,9 +435,9 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
}
if return_labels:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
elif model_class in get_values(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING):
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
@@ -457,17 +458,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.batch_size, dtype=torch.float, device=torch_device
)
elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in [
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*MODEL_FOR_MASKED_LM_MAPPING.values(),
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import AlbertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict

View File

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_tf_available
@@ -39,6 +40,8 @@ if is_tf_available():
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertModel,
TFFunnelBaseModel,
TFFunnelModel,
TFGPT2LMHeadModel,
TFRobertaForMaskedLM,
TFT5ForConditionalGeneration,
@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, TFFunnelModel)
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
self.assertIsInstance(model, TFFunnelBaseModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = TFAutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, TFFunnelBaseModel)
def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `isinstance` check would
# otherwise be triggered by the parents and return the wrong configuration type when using auto models
@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
for parent_config, parent_model in mapping[: index + 1]:
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
self.assertFalse(issubclass(child_config, parent_config))
self.assertFalse(issubclass(child_model, parent_model))
# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import BertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
@@ -282,7 +283,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict

View File

@@ -25,6 +25,7 @@ from importlib import import_module
from typing import List, Tuple
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
@@ -89,7 +90,7 @@ class TFModelTesterMixin:
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
inputs_dict = copy.deepcopy(inputs_dict)
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
if isinstance(v, tf.Tensor) and v.ndim > 0
@@ -98,21 +99,21 @@ class TFModelTesterMixin:
}
if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]:
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
@@ -580,7 +581,7 @@ class TFModelTesterMixin:
),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else:
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
@@ -796,9 +797,9 @@ class TFModelTesterMixin:
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
list_lm_models = (
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
)
for model_class in self.all_model_classes:
@@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
]
loss_size = tf.size(added_label)
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
# if the loss is the causal LM loss, the labels are shifted, so one label per batch is cut
loss_size = loss_size - self.model_tester.batch_size

View File

@@ -19,6 +19,8 @@ import os
import re
from pathlib import Path
from transformers.models.auto import get_values
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py
@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"RagModel",
@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"T5Stack",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
@@ -153,7 +153,7 @@ def get_model_modules():
def get_models(module):
""" Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name:
continue
@@ -249,10 +249,13 @@ def get_all_auto_configured_models():
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls.__name__ for cls in result]