mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
parent
97ccf67bb3
commit
ba8b1f4754
@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
|
||||
# get abs dir
|
||||
save_directory = os.path.abspath(save_directory)
|
||||
# save config as well
|
||||
self.config.architectures = [self.__class__.__name__[4:]]
|
||||
self.config.save_pretrained(save_directory)
|
||||
|
||||
# save model
|
||||
|
@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
logger.info(f"Saved model created in {saved_model_dir}")
|
||||
|
||||
# Save configuration file
|
||||
self.config.architectures = [self.__class__.__name__[2:]]
|
||||
self.config.save_pretrained(save_directory)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
|
@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"auto_factory": ["get_values"],
|
||||
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
|
||||
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
|
||||
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
|
||||
@ -104,6 +105,7 @@ if is_flax_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .auto_factory import get_values
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
||||
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
|
@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
|
||||
"""
|
||||
|
||||
|
||||
def _get_model_class(config, model_mapping):
|
||||
supported_models = model_mapping[type(config)]
|
||||
if not isinstance(supported_models, (list, tuple)):
|
||||
return supported_models
|
||||
|
||||
name_to_model = {model.__name__: model for model in supported_models}
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in name_to_model:
|
||||
return name_to_model[arch]
|
||||
elif f"TF{arch}" in name_to_model:
|
||||
return name_to_model[f"TF{arch}"]
|
||||
elif f"Flax{arch}" in name_to_model:
|
||||
return name_to_model[f"Flax{arch}"]
|
||||
|
||||
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
|
||||
# defaults.
|
||||
return supported_models[0]
|
||||
|
||||
|
||||
class _BaseAutoModelClass:
|
||||
# Base class for auto models.
|
||||
_model_mapping = None
|
||||
@ -341,7 +361,8 @@ class _BaseAutoModelClass:
|
||||
|
||||
def from_config(cls, config, **kwargs):
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
return cls._model_mapping[type(config)](config, **kwargs)
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
return model_class(config, **kwargs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
@ -356,9 +377,8 @@ class _BaseAutoModelClass:
|
||||
)
|
||||
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
return cls._model_mapping[type(config)].from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
|
||||
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
||||
new_class.from_pretrained = classmethod(from_pretrained)
|
||||
return new_class
|
||||
|
||||
|
||||
def get_values(model_mapping):
|
||||
result = []
|
||||
for model in model_mapping.values():
|
||||
if isinstance(model, (list, tuple)):
|
||||
result += list(model)
|
||||
else:
|
||||
result.append(model)
|
||||
|
||||
return result
|
||||
|
@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
)
|
||||
|
||||
|
||||
def _get_class_name(model_class):
|
||||
if isinstance(model_class, (list, tuple)):
|
||||
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
|
||||
return f":class:`~transformers.{model_class.__name__}`"
|
||||
|
||||
|
||||
def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
||||
if config_to_class is None and not use_model_types:
|
||||
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
|
||||
if use_model_types:
|
||||
if config_to_class is None:
|
||||
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
|
||||
model_type_to_name = {
|
||||
model_type: f":class:`~transformers.{config.__name__}`"
|
||||
for model_type, config in CONFIG_MAPPING.items()
|
||||
}
|
||||
else:
|
||||
model_type_to_name = {
|
||||
model_type: config_to_class[config].__name__
|
||||
model_type: _get_class_name(config_to_class[config])
|
||||
for model_type, config in CONFIG_MAPPING.items()
|
||||
if config in config_to_class
|
||||
}
|
||||
lines = [
|
||||
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
|
||||
for model_type in sorted(model_type_to_name.keys())
|
||||
]
|
||||
else:
|
||||
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
|
||||
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
|
||||
config_to_model_name = {
|
||||
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
|
||||
}
|
||||
lines = [
|
||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
|
||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
|
||||
for config_name in sorted(config_to_name.keys())
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
|
||||
)
|
||||
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
|
||||
from ..funnel.modeling_funnel import (
|
||||
FunnelBaseModel,
|
||||
FunnelForMaskedLM,
|
||||
FunnelForMultipleChoice,
|
||||
FunnelForPreTraining,
|
||||
@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(CTRLConfig, CTRLModel),
|
||||
(ElectraConfig, ElectraModel),
|
||||
(ReformerConfig, ReformerModel),
|
||||
(FunnelConfig, FunnelModel),
|
||||
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
|
||||
(LxmertConfig, LxmertModel),
|
||||
(BertGenerationConfig, BertGenerationEncoder),
|
||||
(DebertaConfig, DebertaModel),
|
||||
|
@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
|
||||
TFFlaubertWithLMHeadModel,
|
||||
)
|
||||
from ..funnel.modeling_tf_funnel import (
|
||||
TFFunnelBaseModel,
|
||||
TFFunnelForMaskedLM,
|
||||
TFFunnelForMultipleChoice,
|
||||
TFFunnelForPreTraining,
|
||||
@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
|
||||
(XLMConfig, TFXLMModel),
|
||||
(CTRLConfig, TFCTRLModel),
|
||||
(ElectraConfig, TFElectraModel),
|
||||
(FunnelConfig, TFFunnelModel),
|
||||
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
|
||||
(DPRConfig, TFDPRQuestionEncoder),
|
||||
(MPNetConfig, TFMPNetModel),
|
||||
(BartConfig, TFBartModel),
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -234,7 +235,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -13,7 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
@ -46,6 +47,8 @@ if is_torch_available():
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
BertModel,
|
||||
FunnelBaseModel,
|
||||
FunnelModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
RobertaForMaskedLM,
|
||||
@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_from_pretrained_with_tuple_values(self):
|
||||
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||
self.assertIsInstance(model, FunnelModel)
|
||||
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = AutoModel.from_config(config)
|
||||
self.assertIsInstance(model, FunnelBaseModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
model = AutoModel.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(model, FunnelBaseModel)
|
||||
|
||||
def test_parents_and_children_in_mappings(self):
|
||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||
# by the parents and will return the wrong configuration type when using auto models
|
||||
@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase):
|
||||
assert not issubclass(
|
||||
child_config, parent_config
|
||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
||||
assert not issubclass(
|
||||
child_model, parent_model
|
||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
||||
|
||||
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||
if not isinstance(child_model, (list, tuple)):
|
||||
child_model = (child_model,)
|
||||
if not isinstance(parent_model, (list, tuple)):
|
||||
parent_model = (parent_model,)
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -444,7 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -19,6 +19,7 @@ import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
@ -458,7 +459,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ from typing import List, Tuple
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
|
||||
|
||||
@ -79,7 +80,7 @@ class ModelTesterMixin:
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict = {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||
@ -88,9 +89,9 @@ class ModelTesterMixin:
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["start_positions"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
@ -98,18 +99,18 @@ class ModelTesterMixin:
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in [
|
||||
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
|
||||
*MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
|
||||
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in [
|
||||
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
@ -229,7 +230,7 @@ class ModelTesterMixin:
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class in MODEL_MAPPING.values():
|
||||
if model_class in get_values(MODEL_MAPPING):
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@ -248,7 +249,7 @@ class ModelTesterMixin:
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class in MODEL_MAPPING.values():
|
||||
if model_class in get_values(MODEL_MAPPING):
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@ -312,7 +313,7 @@ class ModelTesterMixin:
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
@ -19,6 +19,7 @@ import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -352,7 +353,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -292,7 +293,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -29,6 +29,7 @@ if is_flax_available():
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForSequenceClassification,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertForQuestionAnswering,
|
||||
)
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import FunnelTokenizer, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -365,7 +366,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -412,7 +413,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
@ -18,6 +18,7 @@ import copy
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -532,11 +533,11 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
# special case for models like BERT that use multi-loss training for PreTraining
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
|
@ -21,6 +21,7 @@ import os
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -290,7 +291,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -272,7 +273,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
@ -32,6 +32,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -425,7 +426,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict = {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||
@ -434,9 +435,9 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
|
||||
elif model_class in get_values(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
@ -457,17 +458,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.batch_size, dtype=torch.float, device=torch_device
|
||||
)
|
||||
elif model_class in [
|
||||
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
|
||||
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in [
|
||||
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AlbertConfig, is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
|
||||
return inputs_dict
|
||||
|
@ -13,7 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available
|
||||
@ -39,6 +40,8 @@ if is_tf_available():
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFFunnelBaseModel,
|
||||
TFFunnelModel,
|
||||
TFGPT2LMHeadModel,
|
||||
TFRobertaForMaskedLM,
|
||||
TFT5ForConditionalGeneration,
|
||||
@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_from_pretrained_with_tuple_values(self):
|
||||
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||
self.assertIsInstance(model, TFFunnelModel)
|
||||
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = TFAutoModel.from_config(config)
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
model = TFAutoModel.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
def test_parents_and_children_in_mappings(self):
|
||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||
# by the parents and will return the wrong configuration type when using auto models
|
||||
@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
for parent_config, parent_model in mapping[: index + 1]:
|
||||
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
|
||||
self.assertFalse(issubclass(child_config, parent_config))
|
||||
self.assertFalse(issubclass(child_model, parent_model))
|
||||
|
||||
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||
if not isinstance(child_model, (list, tuple)):
|
||||
child_model = (child_model,)
|
||||
if not isinstance(parent_model, (list, tuple)):
|
||||
parent_model = (parent_model,)
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import BertConfig, is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@ -282,7 +283,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
|
||||
return inputs_dict
|
||||
|
@ -25,6 +25,7 @@ from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
@ -89,7 +90,7 @@ class TFModelTesterMixin:
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict = {
|
||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
||||
if isinstance(v, tf.Tensor) and v.ndim > 0
|
||||
@ -98,21 +99,21 @@ class TFModelTesterMixin:
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in [
|
||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
|
||||
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
*get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = tf.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||
@ -580,7 +581,7 @@ class TFModelTesterMixin:
|
||||
),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
||||
}
|
||||
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||
else:
|
||||
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
||||
@ -796,9 +797,9 @@ class TFModelTesterMixin:
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
list_lm_models = (
|
||||
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
|
||||
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
|
||||
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
|
||||
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
|
||||
]
|
||||
loss_size = tf.size(added_label)
|
||||
|
||||
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
||||
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||
# if loss is causal lm loss, labels are shift, so that one label per batch
|
||||
# is cut
|
||||
loss_size = loss_size - self.model_tester.batch_size
|
||||
|
@ -19,6 +19,8 @@ import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.models.auto import get_values
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_repo.py
|
||||
@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"DPRReader",
|
||||
"DPRSpanPredictor",
|
||||
"FlaubertForQuestionAnswering",
|
||||
"FunnelBaseModel",
|
||||
"GPT2DoubleHeadsModel",
|
||||
"OpenAIGPTDoubleHeadsModel",
|
||||
"RagModel",
|
||||
@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"T5Stack",
|
||||
"TFDPRReader",
|
||||
"TFDPRSpanPredictor",
|
||||
"TFFunnelBaseModel",
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"TFRagModel",
|
||||
@ -153,7 +153,7 @@ def get_model_modules():
|
||||
def get_models(module):
|
||||
""" Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel)
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
for attr_name in dir(module):
|
||||
if "Pretrained" in attr_name or "PreTrained" in attr_name:
|
||||
continue
|
||||
@ -249,10 +249,13 @@ def get_all_auto_configured_models():
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values())
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values())
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
|
||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||
return [cls.__name__ for cls in result]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user