From 13504dcbea231d2cae701d1ffdeb0810d62aff81 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 22 Dec 2021 14:43:11 +0100 Subject: [PATCH] Onnx enable tasks for supported models (part 2) (#14700) * Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)" This reverts commit 0f4e39c5599523c110cd713f60a3bfa145dad807. * is_torch_available test to avoid failing imports * sorting parameterize parameters to solve ERROR gw0 gw1 * tests fix * tests fix * GPT2 with past fix * Fixed stateful class attribute change that was breaking things when converting multiple models sequentially * Removed onnx file * Implemented suggestions * Fixed __init__ to resolve conflict with master * Remove commented import --- .../models/albert/configuration_albert.py | 4 - .../models/bart/configuration_bart.py | 236 +++++++++++++-- .../models/bert/configuration_bert.py | 4 - .../distilbert/configuration_distilbert.py | 4 - .../models/gpt2/configuration_gpt2.py | 66 ++-- .../models/gpt_neo/configuration_gpt_neo.py | 53 ++-- .../models/mbart/configuration_mbart.py | 238 +++++++++++++-- .../models/roberta/configuration_roberta.py | 4 - .../models/t5/configuration_t5.py | 113 ++----- .../xlm_roberta/configuration_xlm_roberta.py | 4 - src/transformers/onnx/__init__.py | 16 +- src/transformers/onnx/__main__.py | 10 +- src/transformers/onnx/config.py | 282 ++++++++++++++++-- src/transformers/onnx/convert.py | 2 +- src/transformers/onnx/features.py | 230 +++++++++++--- tests/test_onnx_v2.py | 194 ++++++------ 16 files changed, 1064 insertions(+), 396 deletions(-) diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 4f9b6be85e0..1bd0aa786bf 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -166,7 +166,3 @@ class AlbertOnnxConfig(OnnxConfig): ("token_type_ids", {0: "batch", 1: "sequence"}), ] ) - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 3e978bba50e..05854e9bb7b 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -15,10 +15,13 @@ """ BART model configuration """ import warnings from collections import OrderedDict -from typing import Mapping +from typing import Any, Mapping, Optional +from ... 
import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfigWithPast +from ...file_utils import TensorType, is_torch_available +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension from ...utils import logging @@ -180,30 +183,223 @@ class BartConfig(PretrainedConfig): ) -class BartOnnxConfig(OnnxConfigWithPast): +class BartOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict( - [ - ("input_ids", {0: "batch", 1: "sequence"}), - ("attention_mask", {0: "batch", 1: "sequence"}), - ] - ) + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: figure this case out. + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs @property def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + if self.use_past: - return OrderedDict( - [ - ("last_hidden_state", {0: "batch", 1: "sequence"}), - 
("past_keys", {0: "batch", 2: "sequence"}), - ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), - ] + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. 
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework ) else: - return OrderedDict( - [ - ("last_hidden_state", {0: "batch", 1: "sequence"}), - ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), - ] + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t ) diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index a3a3ef5ac82..885285dfa36 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -168,7 +168,3 @@ class BertOnnxConfig(OnnxConfig): ("token_type_ids", {0: "batch", 1: "sequence"}), ] ) - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/distilbert/configuration_distilbert.py b/src/transformers/models/distilbert/configuration_distilbert.py index 09ffe1619cd..36e47ddea32 100644 --- a/src/transformers/models/distilbert/configuration_distilbert.py +++ b/src/transformers/models/distilbert/configuration_distilbert.py @@ -142,7 +142,3 @@ class DistilBertOnnxConfig(OnnxConfig): ("attention_mask", {0: "batch", 1: "sequence"}), ] ) - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})]) diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index 9ea843a5231..d119fb955bf 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -15,12 +15,12 
@@ # limitations under the License. """ OpenAI GPT-2 configuration """ from collections import OrderedDict -from typing import Any, Mapping, Optional +from typing import Any, List, Mapping, Optional from transformers import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfigWithPast +from ...onnx import OnnxConfigWithPast, PatchingSpec from ...utils import logging @@ -195,29 +195,36 @@ class GPT2Config(PretrainedConfig): class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = OrderedDict({"input_ids": {0: "batch"}}) + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) if self.use_past: - for i in range(self._config.n_layer * 2): - common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"} - - common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} return common_inputs @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}) - if self.use_past: - for i in range(self._config.n_layer * 2): - common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"} + def num_layers(self) -> int: + return self._config.n_layer - return common_outputs - - return common_outputs + @property + def num_attention_heads(self) -> int: + return self._config.n_head def generate_dummy_inputs( self, @@ -227,7 +234,9 @@ class GPT2OnnxConfig(OnnxConfigWithPast): is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: - common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size, seq_length, is_pair, framework + ) # We need to order the input in the way they appears in the forward() ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) @@ -239,14 +248,27 @@ class GPT2OnnxConfig(OnnxConfigWithPast): else: import torch - batch = common_inputs["input_ids"].shape[0] + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) ordered_inputs["past_key_values"] = [ - ( - torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), - torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), - ) - for _ in range(self._config.n_layer) + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, 
past_key_values_length)], dim=1 + ) + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index 5499334c871..3b40bd72b52 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -15,7 +15,7 @@ """ GPT Neo model configuration """ from collections import OrderedDict -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Any, Mapping, Optional from ... import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig @@ -211,10 +211,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) if self.use_past: - for i in range(self._config.num_layers): - common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence"} - common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence"} - + self.fill_with_past_key_values_(common_inputs, direction="inputs") common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} @@ -222,16 +219,8 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): return common_inputs @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - if self.use_past: - for i in range(self._config.num_layers): - common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} - - return common_outputs - - return common_outputs + def num_attention_heads(self) -> int: + return self._config.num_heads def generate_dummy_inputs( self, @@ -241,7 +230,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: - common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size, seq_length, is_pair, framework + ) # We need to order the input in the way they appears in the forward() ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) @@ -253,28 +245,27 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): else: import torch - batch = common_inputs["input_ids"].shape[0] - past_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads) + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) ordered_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self._config.num_layers) + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] if self.use_past: ordered_inputs["attention_mask"] = torch.cat( - [ordered_inputs["attention_mask"], torch.ones(batch, 1)], dim=1 + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 ) return ordered_inputs - @staticmethod - def 
flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: - if name in ["present", "past_key_values"]: - flatten_output = {} - for idx, t in enumerate(field): - flatten_output[f"{name}.{idx}.key"] = t[0] - flatten_output[f"{name}.{idx}.value"] = t[1] - - return flatten_output - - return super().flatten_output_collection_property(name, field) + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 2e4769583ff..f1f08cd75d4 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -14,11 +14,13 @@ # limitations under the License. """ MBART model configuration """ from collections import OrderedDict -from typing import Mapping - -from transformers.onnx import OnnxConfigWithPast +from typing import Any, Mapping, Optional +from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig +from ...file_utils import TensorType, is_torch_available +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension from ...utils import logging @@ -165,30 +167,224 @@ class MBartConfig(PretrainedConfig): ) -class MBartOnnxConfig(OnnxConfigWithPast): +# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart +class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict( - [ - ("input_ids", {0: "batch", 1: "sequence"}), - ("attention_mask", {0: "batch", 1: "sequence"}), - ] - ) + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: figure this case out. 
+ common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs @property def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + if self.use_past: - return OrderedDict( - [ - ("last_hidden_state", {0: "batch", 1: "sequence"}), - ("past_keys", {0: "batch", 2: "sequence"}), - ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), - ] + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > 
num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. 
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework ) else: - return OrderedDict( - [ - ("last_hidden_state", {0: "batch", 1: "sequence"}), - ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), - ] + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t ) diff --git a/src/transformers/models/roberta/configuration_roberta.py b/src/transformers/models/roberta/configuration_roberta.py index db462b17c20..dd3697a5f7a 100644 --- a/src/transformers/models/roberta/configuration_roberta.py +++ b/src/transformers/models/roberta/configuration_roberta.py @@ -77,7 +77,3 @@ class RobertaOnnxConfig(OnnxConfig): ("attention_mask", {0: "batch", 1: "sequence"}), ] ) - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index c4e386fd357..557ff5e187a 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -13,14 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ T5 model configuration """ -from collections import OrderedDict -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import Mapping -from transformers import PreTrainedTokenizer, TensorType - -from ... 
import is_torch_available from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfigWithPast +from ...onnx import OnnxSeq2SeqConfigWithPast from ...utils import logging @@ -124,101 +120,26 @@ class T5Config(PretrainedConfig): ) -class T5OnnxConfig(OnnxConfigWithPast): +class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = OrderedDict( - [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), - ("decoder_input_ids", {0: "batch"}), - ("decoder_attention_mask", {0: "batch"}), - ] - ) + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} if self.use_past: - for i in range(0, self._config.num_layers): - common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"} - common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"} - common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"} - common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"} + self.fill_with_past_key_values_(common_inputs, direction="inputs") return common_inputs @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - - if "last_hidden_state" in common_outputs: - common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"} - - if self.use_past: - for i in range(self._config.num_layers): - common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"} - common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"} - common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"} - common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"} - - if self.task == "default": - common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"} - - return common_outputs - - def generate_dummy_inputs( - self, - tokenizer: PreTrainedTokenizer, - batch_size: int = -1, - seq_length: int = -1, - is_pair: bool = False, - framework: Optional[TensorType] = None, - ) -> Mapping[str, Any]: - - # Generate encoder inputs - encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) - - # Generate decoder inputs - decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework) - decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} - - ordered_inputs = dict(**encoder_inputs, **decoder_inputs) - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - batch = encoder_inputs["input_ids"].shape[0] - encoder_seq_length = encoder_inputs["input_ids"].shape[1] - encoder_shape = ( - batch, - self._config.num_heads, - encoder_seq_length, - self._config.hidden_size // self._config.num_heads, - ) - decoder_shape = (batch, self._config.num_heads, 1, 
self._config.hidden_size // self._config.num_heads) - - ordered_inputs["past_key_values"] = [] - for _ in range(self._config.num_layers): - ordered_inputs["past_key_values"].append( - ( - torch.zeros(decoder_shape), - torch.zeros(decoder_shape), - torch.zeros(encoder_shape), - torch.zeros(encoder_shape), - ) - ) - - return ordered_inputs - - @staticmethod - def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: - if name in ["present", "past_key_values"]: - flatten_output = {} - for idx, t in enumerate(field): - flatten_output[f"{name}.{idx}.decoder.key"] = t[0] - flatten_output[f"{name}.{idx}.decoder.value"] = t[1] - flatten_output[f"{name}.{idx}.encoder.key"] = t[2] - flatten_output[f"{name}.{idx}.encoder.value"] = t[3] - - return flatten_output - - return super().flatten_output_collection_property(name, field) + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py index e0974a52e0a..6ee2d52e865 100644 --- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py @@ -53,7 +53,3 @@ class XLMRobertaOnnxConfig(OnnxConfig): ("attention_mask", {0: "batch", 1: "sequence"}), ] ) - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/onnx/__init__.py b/src/transformers/onnx/__init__.py index 7419e8c21b5..8c146cee745 100644 --- a/src/transformers/onnx/__init__.py +++ b/src/transformers/onnx/__init__.py @@ -19,14 +19,26 @@ from ..file_utils import _LazyModule _import_structure = { - "config": ["EXTERNAL_DATA_FORMAT_SIZE_LIMIT", "OnnxConfig", "OnnxConfigWithPast", "PatchingSpec"], + "config": [ + "EXTERNAL_DATA_FORMAT_SIZE_LIMIT", + "OnnxConfig", + "OnnxConfigWithPast", + "OnnxSeq2SeqConfigWithPast", + "PatchingSpec", + ], "convert": ["export", "validate_model_outputs"], "utils": ["ParameterFormat", "compute_serialized_parameters_size"], } if TYPE_CHECKING: - from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec + from .config import ( + EXTERNAL_DATA_FORMAT_SIZE_LIMIT, + OnnxConfig, + OnnxConfigWithPast, + OnnxSeq2SeqConfigWithPast, + PatchingSpec, + ) from .convert import export, validate_model_outputs from .utils import ParameterFormat, compute_serialized_parameters_size diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index be724423316..eb5d2773b0d 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -32,10 +32,10 @@ def main(): help="Export the model with some additional feature.", ) parser.add_argument( - "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)." + "--opset", type=int, default=None, help="ONNX opset version to export the model with (default 12)." ) parser.add_argument( - "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model." + "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model." 
) parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") @@ -53,6 +53,9 @@ def main(): onnx_config = model_onnx_config(model.config) # Ensure the requested opset is sufficient + if args.opset is None: + args.opset = onnx_config.default_onnx_opset + if args.opset < onnx_config.default_onnx_opset: raise ValueError( f"Opset {args.opset} is not sufficient to export {model_kind}. " @@ -61,6 +64,9 @@ def main(): onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output) + if args.atol is None: + args.atol = onnx_config.atol_for_validation + validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol) logger.info(f"All good, model saved at: {args.output.as_posix()}") diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 8e9e1575b1e..65cedbaa591 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import dataclasses from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple -from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType +from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size @@ -58,6 +59,7 @@ class OnnxConfig(ABC): _TASKS_TO_COMMON_OUTPUTS = { "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), + "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}), "sequence-classification": OrderedDict({"logits": {0: "batch"}}), @@ -119,7 +121,8 @@ class OnnxConfig(ABC): Returns: For each output: its name associated to the axes symbolic name and the axis position within the tensor """ - return self._TASKS_TO_COMMON_OUTPUTS[self.task] + common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task] + return copy.deepcopy(common_outputs) @property def values_override(self) -> Optional[Mapping[str, Any]]: @@ -165,6 +168,16 @@ class OnnxConfig(ABC): """ return DEFAULT_ONNX_OPSET + @property + def atol_for_validation(self) -> float: + """ + What absolute tolerance value to use during model conversion validation. + + Returns: + Float absolute tolerance value. + """ + return 1e-5 + @staticmethod def use_external_data_format(num_parameters: int) -> bool: """ @@ -229,8 +242,8 @@ class OnnxConfig(ABC): orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op) setattr(spec.o, spec.name, orig_op) - @staticmethod - def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + @classmethod + def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]: """ Flatten any potential nested structure expanding the name of the field with the index of the element within the structure. 
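Note on the hunk above: the `flatten_output_collection_property` classmethod, together with the `_flatten_past_key_values_` helpers added further down in `config.py`, is what turns the nested per-layer `past_key_values` / `present` structures into the flat per-tensor names that ONNX expects. A minimal standalone sketch of that flattening behaviour follows, assuming plain placeholder strings instead of torch tensors; `flatten_past` is an illustrative name for this sketch, not a transformers API.

from typing import Any, Dict, Iterable, Sequence, Tuple

def flatten_past(name: str, field: Iterable[Sequence[Any]]) -> Dict[str, Any]:
    """Flatten per-layer (key, value) pairs into '{name}.{idx}.key' / '{name}.{idx}.value' entries."""
    flattened: Dict[str, Any] = {}
    for idx, layer in enumerate(field):
        flattened[f"{name}.{idx}.key"] = layer[0]
        flattened[f"{name}.{idx}.value"] = layer[1]
    return flattened

# Two decoder layers, each contributing a (key, value) pair of cached tensors:
past: Tuple[Tuple[str, str], ...] = (("k0", "v0"), ("k1", "v1"))
print(flatten_past("present", past))
# {'present.0.key': 'k0', 'present.0.value': 'v0', 'present.1.key': 'k1', 'present.1.value': 'v1'}

The seq2seq variant introduced later in this patch applies the same idea to 4-tuples, emitting `.decoder.key`, `.decoder.value`, `.encoder.key` and `.encoder.value` entries per layer.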
@@ -272,6 +285,14 @@ class OnnxConfigWithPast(OnnxConfig, ABC): """ return cls(config, task=task, use_past=True) + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super().outputs + if self.use_past: + self.fill_with_past_key_values_(common_outputs, direction="outputs") + + return common_outputs + @property def values_override(self) -> Optional[Mapping[str, Any]]: if hasattr(self._config, "use_cache"): @@ -279,6 +300,30 @@ class OnnxConfigWithPast(OnnxConfig, ABC): return None + @property + def num_layers(self) -> int: + """ + The number of layers attribute retrieved from the model config. Override this for model configs where the + number of layers attribute is not called `num_layers`. + """ + if not hasattr(self._config, "num_layers"): + raise AttributeError( + "could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this" + ) + return self._config.num_layers + + @property + def num_attention_heads(self) -> int: + """ + The number of attention heads attribute retrieved from the model config. Override this for model configs where + the number of attention heads attribute is not called `num_attention_heads`. + """ + if not hasattr(self._config, "num_attention_heads"): + raise AttributeError( + "could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this" + ) + return self._config.num_attention_heads + def generate_dummy_inputs( self, tokenizer: PreTrainedTokenizer, @@ -287,32 +332,217 @@ class OnnxConfigWithPast(OnnxConfig, ABC): is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: - # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX - batch_size = compute_effective_axis_dimension( - batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0 - ) - # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX - token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + # TODO: should we set seq_length = 1 when self.use_past = True? 
+ common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) - # When use_past the caching mechanism requires inputs to be only 1 single token - fixed_sequence_length = 1 if self.use_past else self.default_sequence_length - seq_length = compute_effective_axis_dimension( - seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add - ) + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch - # Generate dummy inputs according to compute batch and sequence - dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size - return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework))) + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) - @staticmethod - def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + if "attention_mask" in common_inputs: + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + for _ in range(self.num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + + return common_inputs + + def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): + """ + Fill the input_or_ouputs mapping with past_key_values dynamic axes considering. + + Args: + inputs_or_outputs: The mapping to fill. + direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the + output mapping, this is important for axes naming. + + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + name = "past_key_values" if direction == "inputs" else "present" + for i in range(self.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + flattened_output[f"{name}.{idx}.key"] = t[0] + flattened_output[f"{name}.{idx}.value"] = t[1] + + def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]: + flattened_output = {} if name in ["present", "past_key_values"]: - flatten_output = {} for idx, t in enumerate(field): - flatten_output[f"{name}.{idx}.key"] = t[0] - flatten_output[f"{name}.{idx}.value"] = t[1] + self._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super().flatten_output_collection_property(name, field) - return flatten_output + return flattened_output - return super().flatten_output_collection_property(name, field) + +class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast): + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super(OnnxConfigWithPast, self).outputs + # Renaming the outputs axes properly. 
+ for name, axes_names in common_outputs.items(): + sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence" + for axis_idx, name in axes_names.items(): + if "sequence" in name: + axes_names[axis_idx] = sequence_name + # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise + else: + axes_names[axis_idx] = name + if self.use_past: + self.fill_with_past_key_values_(common_outputs, direction="outputs") + + return common_outputs + + @property + def num_layers(self) -> Tuple[int]: + try: + num_layers = super().num_layers + num_layers = (num_layers, num_layers) + except AttributeError: + if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"): + num_layers = (self._config.encoder_layers, self._config.decoder_layers) + else: + raise AttributeError( + "could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this" + ) + + return num_layers + + @property + def num_attention_heads(self) -> Tuple[int]: + try: + num_attention_heads = super().num_attention_heads + num_attention_heads = (num_attention_heads, num_attention_heads) + except AttributeError: + if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"): + num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads) + else: + raise AttributeError( + "could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this" + ) + return num_attention_heads + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + + encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch = common_inputs["input_ids"].shape[0] + encoder_seq_length = common_inputs["input_ids"].shape[1] + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_shape = ( + batch, + num_decoder_attention_heads, + # Not using the same length for past_key_values + decoder_seq_length + 3, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - 
min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + # For encoder-decoder models, past_key_values contains pre-computed values for both the encoder and the + # decoder layers, hence a tuple of 4 tensors instead of 2 + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + + return common_inputs + + def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + name = "past_key_values" if direction == "inputs" else "present" + + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + encoder_sequence = "past_encoder_sequence" + decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence" + + for i in range(min_num_layers): + inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence} + inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence} + inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence} + inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence} + + for i in range(min_num_layers, max_num_layers): + if remaining_side_name == "encoder": + axes_info = {0: "batch", 2: encoder_sequence} + else: + axes_info = {0: "batch", 2: decoder_sequence} + inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + flattened_output[f"{name}.{idx}.decoder.key"] = t[0] + flattened_output[f"{name}.{idx}.decoder.value"] = t[1] + flattened_output[f"{name}.{idx}.encoder.key"] = t[2] + flattened_output[f"{name}.{idx}.encoder.value"] = t[3] diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 313a7fd2e62..041e21832a0 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -191,7 +191,7 @@ def validate_model_outputs( f"{onnx_outputs_set.difference(ref_outputs_set)}" ) else: - logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}") + logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set})") # Check the shape and values match for name, ort_value in zip(onnx_named_outputs, onnx_outputs): diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index d685af4cf77..b1d0f104539 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -1,7 +1,7 @@ from functools import partial, reduce -from typing import Callable, Tuple +from typing import Callable, Dict, Optional, Tuple, Type -from .. import is_torch_available +from .. 
import PretrainedConfig, is_torch_available from ..models.albert import AlbertOnnxConfig from ..models.bart import BartOnnxConfig from ..models.bert import BertOnnxConfig @@ -15,23 +15,43 @@ from ..models.mbart import MBartOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig +from ..utils import logging +from .config import OnnxConfig +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_torch_available(): from transformers import PreTrainedModel from transformers.models.auto import ( AutoModel, AutoModelForCausalLM, + AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, ) +else: + logger.warning( + "The ONNX export features are only supported for PyTorch, you will not be able to export models without it." + ) -def supported_features_mapping(*supported_features, onnx_config_cls=None): - """Generates the mapping between supported features and their corresponding OnnxConfig.""" +def supported_features_mapping( + *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None +) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]: + """ + Generate the mapping between supported the features and their corresponding OnnxConfig for a given model. + + Args: + *supported_features: The names of the supported features. + onnx_config_cls: The OnnxConfig class corresponding to the model. + + Returns: + The dictionary mapping a feature to an OnnxConfig constructor. + """ if onnx_config_cls is None: raise ValueError("A OnnxConfig class must be provided") @@ -47,45 +67,129 @@ def supported_features_mapping(*supported_features, onnx_config_cls=None): class FeaturesManager: - _TASKS_TO_AUTOMODELS = { - "default": AutoModel, - "causal-lm": AutoModelForCausalLM, - "seq2seq-lm": AutoModelForSeq2SeqLM, - "sequence-classification": AutoModelForSequenceClassification, - "token-classification": AutoModelForTokenClassification, - "multiple-choice": AutoModelForMultipleChoice, - "question-answering": AutoModelForQuestionAnswering, - } + if is_torch_available(): + _TASKS_TO_AUTOMODELS = { + "default": AutoModel, + "masked-lm": AutoModelForMaskedLM, + "causal-lm": AutoModelForCausalLM, + "seq2seq-lm": AutoModelForSeq2SeqLM, + "sequence-classification": AutoModelForSequenceClassification, + "token-classification": AutoModelForTokenClassification, + "multiple-choice": AutoModelForMultipleChoice, + "question-answering": AutoModelForQuestionAnswering, + } + else: + _TASKS_TO_AUTOMODELS = {} # Set of model topologies we support associated to the features supported by each topology and the factory - _SUPPORTED_MODEL_KIND = { - "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig), - "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), - "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), - "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), - "camembert": supported_features_mapping( + _SUPPORTED_MODEL_TYPE = { + "albert": supported_features_mapping( "default", + "masked-lm", + "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=AlbertOnnxConfig, + ), + "bart": supported_features_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + 
"sequence-classification", + "question-answering", + onnx_config_cls=BartOnnxConfig, + ), + "mbart": supported_features_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + "sequence-classification", + "question-answering", + onnx_config_cls=MBartOnnxConfig, + ), + "bert": supported_features_mapping( + "default", + "masked-lm", "causal-lm", "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=BertOnnxConfig, + ), + "camembert": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + # "multiple-choice", "token-classification", "question-answering", onnx_config_cls=CamembertOnnxConfig, ), - "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), - "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), - "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig), - "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig), + "distilbert": supported_features_mapping( + "default", + "masked-lm", + "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=DistilBertOnnxConfig, + ), + "longformer": supported_features_mapping( + "default", + "masked-lm", + "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=LongformerOnnxConfig, + ), + "roberta": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=RobertaOnnxConfig, + ), "t5": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig ), - "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig), - "gpt-neo": supported_features_mapping( + "xlm-roberta": supported_features_mapping( "default", + "masked-lm", "causal-lm", "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=XLMRobertaOnnxConfig, + ), + "gpt2": supported_features_mapping( + "default", "default-with-past", + "causal-lm", "causal-lm-with-past", - "sequence-classification-with-past", + "sequence-classification", + "token-classification", + onnx_config_cls=GPT2OnnxConfig, + ), + "gpt-neo": supported_features_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "sequence-classification", onnx_config_cls=GPTNeoOnnxConfig, ), "layoutlm": supported_features_mapping( @@ -97,23 +201,46 @@ class FeaturesManager: ), } - AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values()))) + AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) + + @staticmethod + def get_supported_features_for_model_type( + model_type: str, model_name: Optional[str] = None + ) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]: + """ + Try to retrieve the feature -> OnnxConfig constructor map from the model type. + + Args: + model_type: The model type to retrieve the supported features for. + model_name: The name attribute of the model object, only used for the exception message. 
+
+        Returns:
+            The dictionary mapping each feature to a corresponding OnnxConfig constructor.
+        """
+        model_type = model_type.lower()
+        if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
+            model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
+            raise KeyError(
+                f"{model_type_and_model_name} is not supported yet. "
+                f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
+                f"If you want to support {model_type} please propose a PR or open up an issue."
+            )
+        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]
 
     @staticmethod
     def feature_to_task(feature: str) -> str:
         return feature.replace("-with-past", "")
 
     @staticmethod
-    def get_model_from_feature(feature: str, model: str):
+    def get_model_class_for_feature(feature: str) -> Type:
         """
-        Attempt to retrieve a model from a model's name and the feature to be enabled.
+        Attempt to retrieve an AutoModel class from a feature name.
 
         Args:
-            feature: The feature required
-            model: The name of the model to export
+            feature: The feature required.
 
         Returns:
-
+            The AutoModel class corresponding to the feature.
         """
         task = FeaturesManager.feature_to_task(feature)
         if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
@@ -121,38 +248,43 @@ class FeaturesManager:
                 f"Unknown task: {feature}. "
                 f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
             )
+        return FeaturesManager._TASKS_TO_AUTOMODELS[task]
 
-        return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
+    def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
+        """
+        Attempt to retrieve a model from a model's name and the feature to be enabled.
+
+        Args:
+            feature: The feature required.
+            model: The name of the model to export.
+
+        Returns:
+            The instance of the model.
+
+        """
+        model_class = FeaturesManager.get_model_class_for_feature(feature)
+        return model_class.from_pretrained(model)
 
     @staticmethod
     def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
         """
-        Check whether or not the model has the requested features
+        Check whether or not the model has the requested features.
 
         Args:
-            model: The model to export
-            feature: The name of the feature to check if it is available
+            model: The model to export.
+            feature: The name of the feature to check if it is available.
 
         Returns:
-            (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
+            (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties.
         """
         model_type = model.config.model_type.replace("_", "-")
         model_name = getattr(model, "name", "")
-        model_name = f"({model_name})" if model_name else ""
-        if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
-            raise KeyError(
-                f"{model.config.model_type} ({model_name}) is not supported yet. "
-                f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
-                f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
-            )
-
-        # Look for the features
-        model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
+        model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
         if feature not in model_features:
             raise ValueError(
                 f"{model.config.model_type} doesn't support feature {feature}. "
-                f"Supported values are: {list(model_features.keys())}"
+                f"Supported values are: {model_features}"
             )
 
-        return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
+        return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
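Taken together, the FeaturesManager entry points above replace the old one-shot lookup with separate steps for feature discovery and model-class resolution. A hedged usage sketch (not part of the patch; the model type, checkpoint and feature names are only examples, and PyTorch must be installed):

    # Illustrative sketch of the FeaturesManager API introduced above; requires PyTorch.
    from transformers.onnx.features import FeaturesManager

    # Feature -> OnnxConfig constructor map for a model type (raises KeyError if unsupported).
    features = FeaturesManager.get_supported_features_for_model_type("distilbert")
    print(sorted(features))  # e.g. ["default", "masked-lm", "question-answering", ...]

    # Resolve the AutoModel class for a feature, then load a checkpoint with it.
    model_class = FeaturesManager.get_model_class_for_feature("sequence-classification")
    model = model_class.from_pretrained("distilbert-base-cased")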
" - f"Supported values are: {list(model_features.keys())}" + f"Supported values are: {model_features}" ) - return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature] + return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature] diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index 861b781442e..27cc10e9fac 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -3,33 +3,8 @@ from tempfile import NamedTemporaryFile from unittest import TestCase from unittest.mock import patch -from transformers import ( # LongformerConfig,; T5Config, - AlbertConfig, - AutoTokenizer, - BartConfig, - DistilBertConfig, - GPT2Config, - GPTNeoConfig, - LayoutLMConfig, - MBartConfig, - RobertaConfig, - XLMRobertaConfig, - is_torch_available, -) -from transformers.models.albert import AlbertOnnxConfig -from transformers.models.bart import BartOnnxConfig -from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig -from transformers.models.distilbert import DistilBertOnnxConfig - -# from transformers.models.longformer import LongformerOnnxConfig -from transformers.models.gpt2 import GPT2OnnxConfig -from transformers.models.gpt_neo import GPTNeoOnnxConfig -from transformers.models.layoutlm import LayoutLMOnnxConfig -from transformers.models.mbart import MBartOnnxConfig -from transformers.models.roberta import RobertaOnnxConfig - -# from transformers.models.t5 import T5OnnxConfig -from transformers.models.xlm_roberta import XLMRobertaOnnxConfig +from parameterized import parameterized +from transformers import AutoConfig, AutoTokenizer, is_torch_available from transformers.onnx import ( EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, @@ -37,7 +12,12 @@ from transformers.onnx import ( export, validate_model_outputs, ) -from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast +from transformers.onnx.config import OnnxConfigWithPast + + +if is_torch_available(): + from transformers.onnx.features import FeaturesManager + from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size from transformers.testing_utils import require_onnx, require_torch, slow @@ -139,11 +119,12 @@ class OnnxConfigWithPastTestCaseV2(TestCase): Cover the tests for model which have use_cache feature (i.e. 
"with_past" for ONNX) """ - SUPPORTED_WITH_PAST_CONFIGS = { - ("BART", BartConfig), - ("GPT2", GPT2Config), - # ("T5", T5Config) - } + SUPPORTED_WITH_PAST_CONFIGS = {} + # SUPPORTED_WITH_PAST_CONFIGS = { + # ("BART", BartConfig), + # ("GPT2", GPT2Config), + # # ("T5", T5Config) + # } @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) def test_use_past(self): @@ -187,40 +168,41 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ) -if is_torch_available(): - from transformers import ( # T5Model, - AlbertModel, - BartModel, - BertModel, - DistilBertModel, - GPT2Model, - GPTNeoModel, - LayoutLMModel, - MBartModel, - RobertaModel, - XLMRobertaModel, - ) +PYTORCH_EXPORT_MODELS = { + ("albert", "hf-internal-testing/tiny-albert"), + ("bert", "bert-base-cased"), + ("camembert", "camembert-base"), + ("distilbert", "distilbert-base-cased"), + # ("longFormer", "longformer-base-4096"), + ("roberta", "roberta-base"), + ("xlm-roberta", "xlm-roberta-base"), + ("layoutlm", "microsoft/layoutlm-base-uncased"), +} - PYTORCH_EXPORT_DEFAULT_MODELS = { - ("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig), - ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), - ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig), - ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig), - ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), - ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig), - # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), - ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), - ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), - ("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig), - ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig), - # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), - } +PYTORCH_EXPORT_WITH_PAST_MODELS = { + ("gpt2", "gpt2"), + ("gpt-neo", "EleutherAI/gpt-neo-125M"), +} - PYTORCH_EXPORT_WITH_PAST_MODELS = { - # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), - # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), - # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig) - } +PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { + ("bart", "facebook/bart-base"), + ("mbart", "sshleifer/tiny-mbart"), + ("t5", "t5-small"), +} + + +def _get_models_to_test(export_models_list): + models_to_test = [] + if not is_torch_available(): + # Returning some dummy test that should not be ever called because of the @require_torch decorator. + # The reason for not returning an empty list is because parameterized.expand complains when it's empty. 
+ return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)] + for (name, model) in export_models_list: + for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type( + name + ).items(): + models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor)) + return sorted(models_to_test) class OnnxExportTestCaseV2(TestCase): @@ -228,52 +210,52 @@ class OnnxExportTestCaseV2(TestCase): Integration tests ensuring supported models are correctly exported """ - @slow - @require_torch - def test_pytorch_export_default(self): + def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): from transformers.onnx import export - for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS: - with self.subTest(name): - self.assertTrue(hasattr(onnx_config_class, "from_model_config")) + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model) - model = model_class(config_class.from_pretrained(model)) - onnx_config = onnx_config_class.from_model_config(model.config) + # Useful for causal lm models that do not use pad tokens. + if not getattr(config, "pad_token_id", None): + config.pad_token_id = tokenizer.eos_token_id - with NamedTemporaryFile("w") as output: - onnx_inputs, onnx_outputs = export( - tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name) - ) + model_class = FeaturesManager.get_model_class_for_feature(feature) + model = model_class.from_config(config) + onnx_config = onnx_config_class_constructor(model.config) - try: - validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5) - except ValueError as ve: - self.fail(f"{name} -> {ve}") - - @slow - @require_torch - def test_pytorch_export_with_past(self): - from transformers.onnx import export - - for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS: - with self.subTest(name): - self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()") - - tokenizer = AutoTokenizer.from_pretrained(model) - model = model_class(config_class()) - onnx_config = onnx_config_class.with_past(model.config) - - self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.") - self.assertTrue( - onnx_config.use_past, "OnnxConfigWithPast.use_past should be if called with with_past()" + with NamedTemporaryFile("w") as output: + try: + onnx_inputs, onnx_outputs = export( + tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name) ) + validate_model_outputs( + onnx_config, + tokenizer, + model, + Path(output.name), + onnx_outputs, + onnx_config.atol_for_validation, + ) + except (RuntimeError, ValueError) as e: + self.fail(f"{name}, {feature} -> {e}") - with NamedTemporaryFile("w") as output: - output = Path(output.name) - onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS)) + @slow + @require_torch + def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): + self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) - try: - validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5) - except ValueError as ve: - 
self.fail(f"{name} -> {ve}") + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS)) + @slow + @require_torch + def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor): + self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) + + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS)) + @slow + @require_torch + def test_pytorch_export_seq2seq_with_past( + self, test_name, name, model_name, feature, onnx_config_class_constructor + ): + self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)