diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py
index 1f871bf71d3..e6b1a07a6b4 100644
--- a/src/transformers/models/albert/configuration_albert.py
+++ b/src/transformers/models/albert/configuration_albert.py
@@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig):
 class AlbertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py
index 908c6cd432c..893e6fb6d82 100644
--- a/src/transformers/models/bert/configuration_bert.py
+++ b/src/transformers/models/bert/configuration_bert.py
@@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig):
 class BertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py
index 15efd9c2c8f..371846982fd 100644
--- a/src/transformers/models/big_bird/configuration_big_bird.py
+++ b/src/transformers/models/big_bird/configuration_big_bird.py
@@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig):
 class BigBirdOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/camembert/configuration_camembert.py b/src/transformers/models/camembert/configuration_camembert.py
index a65ebd7c448..47d2b3086f0 100644
--- a/src/transformers/models/camembert/configuration_camembert.py
+++ b/src/transformers/models/camembert/configuration_camembert.py
@@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig):
 class CamembertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/data2vec/configuration_data2vec_text.py b/src/transformers/models/data2vec/configuration_data2vec_text.py
index 0356e6c0987..3258ec716b2 100644
--- a/src/transformers/models/data2vec/configuration_data2vec_text.py
+++ b/src/transformers/models/data2vec/configuration_data2vec_text.py
@@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig):
 class Data2VecTextOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/distilbert/configuration_distilbert.py b/src/transformers/models/distilbert/configuration_distilbert.py
index 05027e31782..59752bbe7e1 100644
--- a/src/transformers/models/distilbert/configuration_distilbert.py
+++ b/src/transformers/models/distilbert/configuration_distilbert.py
@@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig):
 class DistilBertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py
index 9b4525d3dce..765498ef833 100644
--- a/src/transformers/models/electra/configuration_electra.py
+++ b/src/transformers/models/electra/configuration_electra.py
@@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig):
 class ElectraOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/flaubert/configuration_flaubert.py b/src/transformers/models/flaubert/configuration_flaubert.py
index e4ec6414c29..f9913336d3d 100644
--- a/src/transformers/models/flaubert/configuration_flaubert.py
+++ b/src/transformers/models/flaubert/configuration_flaubert.py
@@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig):
 class FlaubertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
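Note (not part of the diff): a quick way to sanity-check the new `inputs` property is to build one of these ONNX configs for the `multiple-choice` task and print the declared axes. This sketch assumes the existing `OnnxConfig.from_model_config` constructor and a default `BertConfig`:

```python
from transformers import BertConfig
from transformers.models.bert.configuration_bert import BertOnnxConfig

# Build the ONNX config for the multiple-choice task and inspect the dynamic axes.
onnx_config = BertOnnxConfig.from_model_config(BertConfig(), task="multiple-choice")
print(onnx_config.inputs)
# Every input should now map to {0: "batch", 1: "choice", 2: "sequence"}.
```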
diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py
index a34c1d5c5c9..a840ba14129 100644
--- a/src/transformers/models/gpt2/configuration_gpt2.py
+++ b/src/transformers/models/gpt2/configuration_gpt2.py
@@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )
 
         # We need to order the input in the way they appears in the forward()
diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
index 1e453de3f85..cc2eab23378 100644
--- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )
 
         # We need to order the input in the way they appears in the forward()
diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py
index 25a193cdd91..620e6a60ab6 100644
--- a/src/transformers/models/gptj/configuration_gptj.py
+++ b/src/transformers/models/gptj/configuration_gptj.py
@@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )
 
         # We need to order the input in the way they appears in the forward()
diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py
index 6b7e0537e0c..44ffa4375f1 100644
--- a/src/transformers/models/layoutlm/configuration_layoutlm.py
+++ b/src/transformers/models/layoutlm/configuration_layoutlm.py
@@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig):
             Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
         """
 
-        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+        input_dict = super().generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
 
         # Generate a dummy bbox
         box = [48, 84, 73, 128]
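Note (illustrative, not part of the diff): the positional calls above are switched to keyword arguments because the base `generate_dummy_inputs()` gains a `num_choices` parameter between `seq_length` and `is_pair` (see the `onnx/config.py` hunks below), so an old-style positional call would shift every argument after `seq_length` into the wrong slot. A minimal sketch of the hazard, using a stand-in function:

```python
# Stand-in with the new parameter order; only the argument binding matters here.
def generate_dummy_inputs(tokenizer, batch_size=-1, seq_length=-1, num_choices=-1, is_pair=False, framework=None):
    return {"num_choices": num_choices, "is_pair": is_pair, "framework": framework}

# Old positional style: is_pair lands in num_choices and framework lands in is_pair.
print(generate_dummy_inputs("tok", 2, 8, False, "pt"))
# {'num_choices': False, 'is_pair': 'pt', 'framework': None}

# Keyword style (what the diff uses): every value reaches the intended parameter.
print(generate_dummy_inputs("tok", batch_size=2, seq_length=8, is_pair=False, framework="pt"))
# {'num_choices': -1, 'is_pair': False, 'framework': 'pt'}
```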
diff --git a/src/transformers/models/roberta/configuration_roberta.py b/src/transformers/models/roberta/configuration_roberta.py
index 6c54cf7ccd5..e20bf36bedc 100644
--- a/src/transformers/models/roberta/configuration_roberta.py
+++ b/src/transformers/models/roberta/configuration_roberta.py
@@ -70,9 +70,13 @@ class RobertaConfig(BertConfig):
 class RobertaOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
index 568aedcfca0..e8b998e0f0d 100644
--- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
@@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig):
 class XLMRobertaOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py
index 14e3ffc2eec..a80b703f660 100644
--- a/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py
@@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig):
 class XLMRobertaXLOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 8a7437c45fc..8f886a5d7a4 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -71,6 +71,7 @@ class OnnxConfig(ABC):
 
     default_fixed_batch = 2
     default_fixed_sequence = 8
+    default_fixed_num_choices = 4
     torch_onnx_minimum_version = version.parse("1.8")
     _tasks_to_common_outputs = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
@@ -174,6 +175,16 @@ class OnnxConfig(ABC):
         """
         return OnnxConfig.default_fixed_sequence
 
+    @property
+    def default_num_choices(self) -> int:
+        """
+        The default number of choices to use if no other indication
+
+        Returns:
+            Integer > 0
+        """
+        return OnnxConfig.default_fixed_num_choices
+
     @property
     def default_onnx_opset(self) -> int:
         """
@@ -240,6 +251,7 @@ class OnnxConfig(ABC):
         preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
         batch_size: int = -1,
         seq_length: int = -1,
+        num_choices: int = -1,
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
         num_channels: int = 3,
@@ -255,6 +267,8 @@ class OnnxConfig(ABC):
                 The preprocessor associated with this model configuration.
             batch_size (`int`, *optional*, defaults to -1):
                 The batch size to export the model for (-1 means dynamic axis).
+            num_choices (`int`, *optional*, defaults to -1):
+                The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
             seq_length (`int`, *optional*, defaults to -1):
                 The sequence length to export the model for (-1 means dynamic axis).
             is_pair (`bool`, *optional*, defaults to `False`):
@@ -295,6 +309,19 @@ class OnnxConfig(ABC):
             )
             # Generate dummy inputs according to compute batch and sequence
             dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
+            if self.task == "multiple-choice":
+                # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
+                # made by ONNX
+                num_choices = compute_effective_axis_dimension(
+                    num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
+                )
+                dummy_input = dummy_input * num_choices
+                # The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
+                tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
+                # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
+                for k, v in tokenized_input.items():
+                    tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
+                return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
             return dict(preprocessor(dummy_input, return_tensors=framework))
         elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
             # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
@@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
     ) -> Mapping[str, Any]:
 
         # TODO: should we set seq_length = 1 when self.use_past = True?
-        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+        common_inputs = super().generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
 
         if self.use_past:
             if not is_torch_available():
@@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
     ) -> Mapping[str, Any]:
 
         encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )
 
         # Generate decoder inputs
         decoder_seq_length = seq_length if not self.use_past else 1
         decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, decoder_seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
         )
         decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
         common_inputs = dict(**encoder_inputs, **decoder_inputs)
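Note (a rough check, not part of the diff): with the new branch in place, dummy inputs for the `multiple-choice` task should come out with shape `[batch_size, num_choices, seq_length]`, falling back to the fixed defaults of 2 and 4 when the axes are dynamic. This sketch assumes a local PyTorch install and the `bert-base-uncased` tokenizer:

```python
from transformers import AutoTokenizer, BertConfig, TensorType
from transformers.models.bert.configuration_bert import BertOnnxConfig

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
onnx_config = BertOnnxConfig.from_model_config(BertConfig(), task="multiple-choice")

# With dynamic axes (-1), batch falls back to 2 and num_choices to 4.
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(dummy_inputs["input_ids"].shape)  # expected: torch.Size([2, 4, seq_length])
```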
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 3875da445f1..5a64097e8ca 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
 from ..models.blenderbot import BlenderbotOnnxConfig
 from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
+from ..models.data2vec import Data2VecTextOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.electra import ElectraOnnxConfig
 from ..models.flaubert import FlaubertOnnxConfig
@@ -120,7 +121,7 @@ class FeaturesManager:
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=AlbertOnnxConfig,
@@ -152,7 +153,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=BertOnnxConfig,
@@ -162,6 +163,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=BigBirdOnnxConfig,
@@ -170,7 +172,7 @@ class FeaturesManager:
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=IBertOnnxConfig,
@@ -180,7 +182,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=CamembertOnnxConfig,
@@ -189,7 +191,7 @@ class FeaturesManager:
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=DistilBertOnnxConfig,
@@ -199,6 +201,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=FlaubertOnnxConfig,
@@ -220,7 +223,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=RobertaOnnxConfig,
@@ -233,7 +236,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=XLMRobertaOnnxConfig,
@@ -276,6 +279,7 @@ class FeaturesManager:
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=ElectraOnnxConfig,
@@ -300,6 +304,15 @@ class FeaturesManager:
             "seq2seq-lm-with-past",
             onnx_config_cls=BlenderbotSmallOnnxConfig,
         ),
+        "data2vec-text": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=Data2VecTextOnnxConfig,
+        ),
     }
 
     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 1ddaa78ce6c..9e3ee736165 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("vit", "google/vit-base-patch16-224"),
     ("beit", "microsoft/beit-base-patch16-224"),
+    ("data2vec-text", "facebook/data2vec-text-base"),
 }
 
 PYTORCH_EXPORT_WITH_PAST_MODELS = {
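Note (not part of the diff): once the `multiple-choice` entries are uncommented in `FeaturesManager` and `data2vec-text` is registered, both should be visible through the public feature lookup, and the usual `python -m transformers.onnx --feature=multiple-choice ...` export path should accept them. A small sketch of the lookup, assuming the existing `get_supported_features_for_model_type` helper:

```python
from transformers.onnx.features import FeaturesManager

# "multiple-choice" should now be listed for the model types enabled above.
print("multiple-choice" in FeaturesManager.get_supported_features_for_model_type("bert"))
# data2vec-text is registered as a new model type with its own feature set.
print(sorted(FeaturesManager.get_supported_features_for_model_type("data2vec-text")))
```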