This reverts commit 0c70f145d1.

This commit is contained in:
parent 0c70f145d1
commit 0f4e39c559
@@ -167,3 +167,7 @@ class AlbertOnnxConfig(OnnxConfig):
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -15,12 +15,10 @@
""" BART model configuration """
import warnings
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Mapping

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx import OnnxConfigWithPast
from ...utils import logging

@@ -182,174 +180,30 @@ class BartConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||
class BartOnnxConfig(OnnxConfigWithPast):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
elif self.task == "causal-lm":
|
||||
# TODO: figure this case out.
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
if self.use_past:
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
for i in range(num_encoder_layers):
|
||||
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
|
||||
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
return common_inputs
|
||||
return OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("past_keys", {0: "batch", 2: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_outputs = super(OnnxConfigWithPast, self).outputs
|
||||
if self.use_past:
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
for i in range(num_encoder_layers):
|
||||
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
# Generate decoder inputs
|
||||
decoder_seq_length = seq_length if not self.use_past else 1
|
||||
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||
)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
||||
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||
encoder_shape = (
|
||||
batch,
|
||||
num_encoder_attention_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // num_encoder_attention_heads,
|
||||
)
|
||||
decoder_past_length = decoder_seq_length + 3
|
||||
decoder_shape = (
|
||||
batch,
|
||||
num_decoder_attention_heads,
|
||||
decoder_past_length,
|
||||
self._config.hidden_size // num_decoder_attention_heads,
|
||||
)
|
||||
|
||||
common_inputs["decoder_attention_mask"] = torch.cat(
|
||||
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
||||
)
|
||||
|
||||
common_inputs["past_key_values"] = []
|
||||
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||
|
||||
for _ in range(min_num_layers):
|
||||
common_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: test this.
|
||||
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
||||
for _ in range(min_num_layers, max_num_layers):
|
||||
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||
|
||||
elif self.task == "causal-lm":
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
num_encoder_attention_heads, _ = self.num_attention_heads
|
||||
past_shape = (
|
||||
batch,
|
||||
num_encoder_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // num_encoder_attention_heads,
|
||||
)
|
||||
|
||||
common_inputs["attention_mask"] = torch.cat(
|
||||
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
)
|
||||
common_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
|
||||
]
|
||||
else:
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
return common_inputs
|
||||
|
||||
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
|
||||
else:
|
||||
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
|
||||
flattened_output, name, idx, t
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
@@ -169,3 +169,7 @@ class BertOnnxConfig(OnnxConfig):
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -142,3 +142,7 @@ class DistilBertOnnxConfig(OnnxConfig):
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
@@ -15,12 +15,12 @@
# limitations under the License.
""" OpenAI GPT-2 configuration """
from collections import OrderedDict
from typing import Any, List, Mapping, Optional
from typing import Any, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...onnx import OnnxConfigWithPast
from ...utils import logging

@@ -194,36 +194,29 @@ class GPT2Config(PretrainedConfig):
|
||||
|
||||
|
||||
class GPT2OnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: str = "default",
|
||||
patching_specs: List[PatchingSpec] = None,
|
||||
use_past: bool = False,
|
||||
):
|
||||
super().__init__(config, task=task, patching_specs=patching_specs)
|
||||
if not getattr(self._config, "pad_token_id", None):
|
||||
# TODO: how to do that better?
|
||||
self._config.pad_token_id = 0
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch"}})
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
for i in range(self._config.n_layer * 2):
|
||||
common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
|
||||
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
else:
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self._config.n_layer
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
for i in range(self._config.n_layer * 2):
|
||||
common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self._config.n_head
|
||||
return common_outputs
|
||||
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
@@ -233,9 +226,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
|
||||
# We need to order the input in the way they appears in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
@@ -247,27 +238,14 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
past_shape = (
|
||||
batch,
|
||||
self.num_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
batch = common_inputs["input_ids"].shape[0]
|
||||
ordered_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
|
||||
(
|
||||
torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
|
||||
torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
|
||||
)
|
||||
for _ in range(self._config.n_layer)
|
||||
]
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
||||
|
@@ -15,7 +15,7 @@
""" GPT Neo model configuration """

from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional

from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
@@ -212,7 +212,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
for i in range(self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence"}
|
||||
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
@@ -220,8 +223,16 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self._config.num_heads
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
|
||||
return common_outputs
|
||||
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
@@ -231,10 +242,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
|
||||
# We need to order the input in the way they appears in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
@@ -246,27 +254,28 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
past_shape = (
|
||||
batch,
|
||||
self.num_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
batch = common_inputs["input_ids"].shape[0]
|
||||
past_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
|
||||
ordered_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self._config.num_layers)
|
||||
]
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
[ordered_inputs["attention_mask"], torch.ones(batch, 1)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
@@ -14,12 +14,11 @@
# limitations under the License.
""" MBART model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Mapping

from transformers.onnx import OnnxConfigWithPast

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...utils import logging

@@ -166,175 +165,30 @@ class MBartConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart
|
||||
class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||
class MBartOnnxConfig(OnnxConfigWithPast):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
elif self.task == "causal-lm":
|
||||
# TODO: figure this case out.
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
if self.use_past:
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
for i in range(num_encoder_layers):
|
||||
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
|
||||
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
return common_inputs
|
||||
return OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("past_keys", {0: "batch", 2: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_outputs = super(OnnxConfigWithPast, self).outputs
|
||||
if self.use_past:
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
for i in range(num_encoder_layers):
|
||||
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
# Generate decoder inputs
|
||||
decoder_seq_length = seq_length if not self.use_past else 1
|
||||
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||
)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
||||
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||
encoder_shape = (
|
||||
batch,
|
||||
num_encoder_attention_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // num_encoder_attention_heads,
|
||||
)
|
||||
decoder_past_length = decoder_seq_length + 3
|
||||
decoder_shape = (
|
||||
batch,
|
||||
num_decoder_attention_heads,
|
||||
decoder_past_length,
|
||||
self._config.hidden_size // num_decoder_attention_heads,
|
||||
)
|
||||
|
||||
common_inputs["decoder_attention_mask"] = torch.cat(
|
||||
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
||||
)
|
||||
|
||||
common_inputs["past_key_values"] = []
|
||||
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||
|
||||
for _ in range(min_num_layers):
|
||||
common_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: test this.
|
||||
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
||||
for _ in range(min_num_layers, max_num_layers):
|
||||
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||
|
||||
elif self.task == "causal-lm":
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
num_encoder_layers, _ = self.num_layers
|
||||
num_encoder_attention_heads, _ = self.num_attention_heads
|
||||
past_shape = (
|
||||
batch,
|
||||
num_encoder_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // num_encoder_attention_heads,
|
||||
)
|
||||
|
||||
common_inputs["attention_mask"] = torch.cat(
|
||||
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
)
|
||||
common_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
|
||||
]
|
||||
else:
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
return common_inputs
|
||||
|
||||
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||
if self.task in ["default", "seq2seq-lm"]:
|
||||
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
|
||||
else:
|
||||
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
|
||||
flattened_output, name, idx, t
|
||||
return OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
@@ -76,3 +76,7 @@ class RobertaOnnxConfig(OnnxConfig):
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration """
from typing import Mapping
from collections import OrderedDict
from typing import Any, Dict, Iterable, Mapping, Optional

# from ... import is_torch_available
from transformers import PreTrainedTokenizer, TensorType

from ... import is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...onnx import OnnxConfigWithPast
from ...utils import logging

@@ -122,26 +125,101 @@ class T5Config(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||
class T5OnnxConfig(OnnxConfigWithPast):
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = {
|
||||
"input_ids": {0: "batch", 1: "encoder_sequence"},
|
||||
"attention_mask": {0: "batch", 1: "encoder_sequence"},
|
||||
}
|
||||
if self.use_past:
|
||||
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
|
||||
common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||
common_inputs = OrderedDict(
|
||||
[
|
||||
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||
("decoder_input_ids", {0: "batch"}),
|
||||
("decoder_attention_mask", {0: "batch"}),
|
||||
]
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
for i in range(0, self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
|
||||
if "last_hidden_state" in common_outputs:
|
||||
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
if self.task == "default":
|
||||
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
# Generate encoder inputs
|
||||
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
|
||||
# Generate decoder inputs
|
||||
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
|
||||
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch = encoder_inputs["input_ids"].shape[0]
|
||||
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
|
||||
encoder_shape = (
|
||||
batch,
|
||||
self._config.num_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // self._config.num_heads,
|
||||
)
|
||||
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
|
||||
|
||||
ordered_inputs["past_key_values"] = []
|
||||
for _ in range(self._config.num_layers):
|
||||
ordered_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
|
||||
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
|
||||
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
@@ -53,3 +53,7 @@ class XLMRobertaOnnxConfig(OnnxConfig):
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
@@ -13,12 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import (
    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
    OnnxConfig,
    OnnxConfigWithPast,
    OnnxSeq2SeqConfigWithPast,
    PatchingSpec,
)
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
from .convert import export, validate_model_outputs
from .utils import ParameterFormat, compute_serialized_parameters_size
@@ -32,10 +32,10 @@ def main():
        help="Export the model with some additional feature.",
    )
    parser.add_argument(
        "--opset", type=int, default=None, help="ONNX opset version to export the model with (default 12)."
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
        "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

@@ -53,9 +53,6 @@ def main():
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient
    if args.opset is None:
        args.opset = onnx_config.default_onnx_opset

    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
@@ -64,9 +61,6 @@ def main():

    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    if args.atol is None:
        args.atol = onnx_config.atol_for_validation

    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")

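For readers unfamiliar with the export flow touched by the hunks above, the following is a minimal Python sketch of the same sequence (load, build the ONNX config, export, validate). It assumes the transformers.onnx API exactly as imported in this diff; the checkpoint name, opset value, and tolerance below are illustrative assumptions, not values taken from the commit.

# Hypothetical usage sketch of the export/validation flow changed by this revert.
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.bert import BertOnnxConfig
from transformers.onnx import export, validate_model_outputs

model_name = "bert-base-cased"  # illustrative checkpoint, not part of the diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

onnx_config = BertOnnxConfig(model.config)  # assumption: default-task constructor
output = Path("model.onnx")

# Mirrors main(): export with the CLI default opset (12) and validate with atol=1e-4.
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, 12, output)
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-4)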
@@ -14,9 +14,9 @@
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional

from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType

from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size

@@ -58,7 +58,6 @@ class OnnxConfig(ABC):

    _TASKS_TO_COMMON_OUTPUTS = {
        "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
        "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
@@ -120,8 +119,7 @@ class OnnxConfig(ABC):
        Returns:
            For each output: its name associated to the axes symbolic name and the axis position within the tensor
        """
        common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task]
        return common_outputs
        return self._TASKS_TO_COMMON_OUTPUTS[self.task]

    @property
    def values_override(self) -> Optional[Mapping[str, Any]]:
@@ -167,16 +165,6 @@ class OnnxConfig(ABC):
|
||||
"""
|
||||
return DEFAULT_ONNX_OPSET
|
||||
|
||||
@property
|
||||
def atol_for_validation(self) -> float:
|
||||
"""
|
||||
What absolute tolerance value to use during model conversion validation.
|
||||
|
||||
Returns:
|
||||
Float absolute tolerance value.
|
||||
"""
|
||||
return 1e-5
|
||||
|
||||
@staticmethod
|
||||
def use_external_data_format(num_parameters: int) -> bool:
|
||||
"""
|
||||
@@ -241,8 +229,8 @@ class OnnxConfig(ABC):
|
||||
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||
setattr(spec.o, spec.name, orig_op)
|
||||
|
||||
@classmethod
|
||||
def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||
structure.
|
||||
@@ -284,14 +272,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
"""
|
||||
return cls(config, task=task, use_past=True)
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_outputs, direction="outputs")
|
||||
|
||||
return common_outputs
|
||||
|
||||
@property
|
||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||
if hasattr(self._config, "use_cache"):
|
||||
@@ -299,30 +279,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
"""
|
||||
The number of layers attribute retrieved from the model config. Override this for model configs where the
|
||||
number of layers attribute is not called `num_layers`.
|
||||
"""
|
||||
if not hasattr(self._config, "num_layers"):
|
||||
raise AttributeError(
|
||||
"could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
|
||||
)
|
||||
return self._config.num_layers
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
"""
|
||||
The number of attention heads attribute retrieved from the model config. Override this for model configs where
|
||||
the number of attention heads attribute is not called `num_attention_heads`.
|
||||
"""
|
||||
if not hasattr(self._config, "num_attention_heads"):
|
||||
raise AttributeError(
|
||||
"could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
|
||||
)
|
||||
return self._config.num_attention_heads
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
@@ -331,217 +287,32 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||
batch_size = compute_effective_axis_dimension(
|
||||
batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
|
||||
)
|
||||
|
||||
# TODO: should we set seq_length = 1 when self.use_past = True?
|
||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
# When use_past the caching mechanism requires inputs to be only 1 single token
|
||||
fixed_sequence_length = 1 if self.use_past else self.default_sequence_length
|
||||
seq_length = compute_effective_axis_dimension(
|
||||
seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
|
||||
)
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
shape = (
|
||||
batch,
|
||||
self.num_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
# Generate dummy inputs according to compute batch and sequence
|
||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
||||
|
||||
if "attention_mask" in common_inputs:
|
||||
common_inputs["attention_mask"] = torch.cat(
|
||||
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
)
|
||||
|
||||
common_inputs["past_key_values"] = []
|
||||
for _ in range(self.num_layers):
|
||||
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||
|
||||
return common_inputs
|
||||
|
||||
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
|
||||
"""
|
||||
Fill the input_or_ouputs mapping with past_key_values dynamic axes considering.
|
||||
|
||||
Args:
|
||||
inputs_or_outputs: The mapping to fill.
|
||||
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
|
||||
output mapping, this is important for axes naming.
|
||||
|
||||
"""
|
||||
if direction not in ["inputs", "outputs"]:
|
||||
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
|
||||
|
||||
name = "past_key_values" if direction == "inputs" else "present"
|
||||
for i in range(self.num_layers):
|
||||
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
|
||||
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||
flattened_output[f"{name}.{idx}.key"] = t[0]
|
||||
flattened_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
flattened_output = {}
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
self._flatten_past_key_values_(flattened_output, name, idx, t)
|
||||
else:
|
||||
flattened_output = super().flatten_output_collection_property(name, field)
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flattened_output
|
||||
return flatten_output
|
||||
|
||||
|
||||
class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task]
|
||||
# Renaming the outputs axes properly.
|
||||
for name, axes_names in common_outputs.items():
|
||||
sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence"
|
||||
for axis_idx, name in axes_names.items():
|
||||
if "sequence" in name:
|
||||
axes_names[axis_idx] = sequence_name
|
||||
# We reset the value as the order in common_outputs (OrderedDict) is lost otherwise
|
||||
else:
|
||||
axes_names[axis_idx] = name
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_outputs, direction="outputs")
|
||||
|
||||
return common_outputs
|
||||
|
||||
@property
|
||||
def num_layers(self) -> Tuple[int]:
|
||||
try:
|
||||
num_layers = super().num_layers
|
||||
num_layers = (num_layers, num_layers)
|
||||
except AttributeError:
|
||||
if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"):
|
||||
num_layers = (self._config.encoder_layers, self._config.decoder_layers)
|
||||
else:
|
||||
raise AttributeError(
|
||||
"could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
|
||||
)
|
||||
|
||||
return num_layers
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> Tuple[int]:
|
||||
try:
|
||||
num_attention_heads = super().num_attention_heads
|
||||
num_attention_heads = (num_attention_heads, num_attention_heads)
|
||||
except AttributeError:
|
||||
if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"):
|
||||
num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
|
||||
else:
|
||||
raise AttributeError(
|
||||
"could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
|
||||
)
|
||||
return num_attention_heads
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
# Generate decoder inputs
|
||||
decoder_seq_length = seq_length if not self.use_past else 1
|
||||
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
||||
)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch = common_inputs["input_ids"].shape[0]
|
||||
encoder_seq_length = common_inputs["input_ids"].shape[1]
|
||||
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
||||
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
||||
encoder_shape = (
|
||||
batch,
|
||||
num_encoder_attention_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // num_encoder_attention_heads,
|
||||
)
|
||||
decoder_shape = (
|
||||
batch,
|
||||
num_decoder_attention_heads,
|
||||
# Not using the same length for past_key_values
|
||||
decoder_seq_length + 3,
|
||||
self._config.hidden_size // num_decoder_attention_heads,
|
||||
)
|
||||
|
||||
common_inputs["past_key_values"] = []
|
||||
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||
|
||||
for _ in range(min_num_layers):
|
||||
# For encoder-decoder models, past_key_values contains pre-computed values for both the encoder and the
|
||||
# decoder layers, hence a tuple of 4 tensors instead of 2
|
||||
common_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: test this.
|
||||
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
||||
for _ in range(min_num_layers, max_num_layers):
|
||||
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
||||
|
||||
return common_inputs
|
||||
|
||||
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
|
||||
if direction not in ["inputs", "outputs"]:
|
||||
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
|
||||
|
||||
name = "past_key_values" if direction == "inputs" else "present"
|
||||
|
||||
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
||||
num_encoder_layers, num_decoder_layers = self.num_layers
|
||||
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
||||
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
||||
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
||||
|
||||
encoder_sequence = "past_encoder_sequence"
|
||||
decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"
|
||||
|
||||
for i in range(min_num_layers):
|
||||
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
|
||||
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
|
||||
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
|
||||
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
|
||||
|
||||
for i in range(min_num_layers, max_num_layers):
|
||||
if remaining_side_name == "encoder":
|
||||
axes_info = {0: "batch", 2: encoder_sequence}
|
||||
else:
|
||||
axes_info = {0: "batch", 2: decoder_sequence}
|
||||
inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info
|
||||
|
||||
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||
flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
|
||||
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
|
||||
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
|
||||
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
@@ -191,7 +191,7 @@ def validate_model_outputs(
            f"{onnx_outputs_set.difference(ref_outputs_set)}"
        )
    else:
        logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set})")
        logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}")

    # Check the shape and values match
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
@@ -1,7 +1,7 @@
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Tuple

from .. import PretrainedConfig, is_torch_available
from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
@@ -15,7 +15,6 @@ from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig
from .config import OnnxConfig


if is_torch_available():
@@ -23,7 +22,6 @@ if is_torch_available():
    from transformers.models.auto import (
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForMaskedLM,
        AutoModelForMultipleChoice,
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
@@ -32,19 +30,8 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
def supported_features_mapping(
|
||||
*supported_features: str, onnx_config_cls: Type[OnnxConfig] = None
|
||||
) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]:
|
||||
"""
|
||||
Generate the mapping between supported the features and their corresponding OnnxConfig for a given model.
|
||||
|
||||
Args:
|
||||
*supported_features: The names of the supported features.
|
||||
onnx_config_cls: The OnnxConfig class corresponding to the model.
|
||||
|
||||
Returns:
|
||||
The dictionary mapping a feature to an OnnxConfig constructor.
|
||||
"""
|
||||
def supported_features_mapping(*supported_features, onnx_config_cls=None):
|
||||
"""Generates the mapping between supported features and their corresponding OnnxConfig."""
|
||||
if onnx_config_cls is None:
|
||||
raise ValueError("A OnnxConfig class must be provided")
|
||||
|
||||
@@ -62,7 +49,6 @@ def supported_features_mapping(
|
||||
class FeaturesManager:
|
||||
_TASKS_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
"masked-lm": AutoModelForMaskedLM,
|
||||
"causal-lm": AutoModelForCausalLM,
|
||||
"seq2seq-lm": AutoModelForSeq2SeqLM,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
@@ -72,110 +58,27 @@ class FeaturesManager:
|
||||
}
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
_SUPPORTED_MODEL_TYPE = {
|
||||
"albert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=AlbertOnnxConfig,
|
||||
),
|
||||
"bart": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
"causal-lm",
|
||||
"causal-lm-with-past",
|
||||
"seq2seq-lm",
|
||||
"seq2seq-lm-with-past",
|
||||
"sequence-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=BartOnnxConfig,
|
||||
),
|
||||
"mbart": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
"causal-lm",
|
||||
"causal-lm-with-past",
|
||||
"seq2seq-lm",
|
||||
"seq2seq-lm-with-past",
|
||||
"sequence-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=MBartOnnxConfig,
|
||||
),
|
||||
"bert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=BertOnnxConfig,
|
||||
),
|
||||
_SUPPORTED_MODEL_KIND = {
|
||||
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
|
||||
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
||||
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
|
||||
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
||||
"camembert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=CamembertOnnxConfig,
|
||||
),
|
||||
"distilbert": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=DistilBertOnnxConfig,
|
||||
),
|
||||
"longformer": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=LongformerOnnxConfig,
|
||||
),
|
||||
"roberta": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=RobertaOnnxConfig,
|
||||
),
|
||||
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
||||
"t5": supported_features_mapping(
|
||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
||||
),
|
||||
"xlm-roberta": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
# "multiple-choice",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
onnx_config_cls=XLMRobertaOnnxConfig,
|
||||
),
|
||||
"gpt2": supported_features_mapping(
|
||||
"default",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
"token-classification",
|
||||
"default-with-past",
|
||||
"causal-lm-with-past",
|
||||
"sequence-classification-with-past",
|
||||
"token-classification-with-past",
|
||||
onnx_config_cls=GPT2OnnxConfig,
|
||||
),
|
||||
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
||||
"gpt-neo": supported_features_mapping(
|
||||
"default",
|
||||
"causal-lm",
|
||||
@ -194,46 +97,23 @@ class FeaturesManager:
|
||||
),
|
||||
}
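Both tables above share the same shape: each supported_features_mapping(...) call records, for one model type, a feature-name -> OnnxConfig-constructor dictionary, and AVAILABLE_FEATURES below is simply the union of those keys. A small read-only illustration of how the table is meant to be consumed (hypothetical snippet, not part of this commit; it uses the post-revert _SUPPORTED_MODEL_KIND name):

from transformers.onnx.features import FeaturesManager

# Hypothetical illustration only: inspect which ONNX features a model type declares.
bert_features = FeaturesManager._SUPPORTED_MODEL_KIND["bert"]
print(sorted(bert_features.keys()))              # after the revert this is just ["default"]
onnx_config_factory = bert_features["default"]   # the BertOnnxConfig-backed constructor registered above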

AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

@staticmethod
def get_supported_features_for_model_type(
model_type: str, model_name: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]:
"""
Try to retrieve the feature -> OnnxConfig constructor map from the model type.

Args:
model_type: The model type to retrieve the supported features for.
model_name: The name attribute of the model object, only used for the exception message.

Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor.
"""
model_type = model_type.lower()
if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
raise KeyError(
f"{model_type_and_model_name} is not supported yet. "
f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))

@staticmethod
def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "")
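Since feature_to_task only strips the "-with-past" suffix, the mapping from features to AutoModel tasks is purely textual. A quick illustration (not part of the diff):

from transformers.onnx.features import FeaturesManager

assert FeaturesManager.feature_to_task("causal-lm-with-past") == "causal-lm"
assert FeaturesManager.feature_to_task("default") == "default"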

@staticmethod
def get_model_class_for_feature(feature: str) -> Type:
def get_model_from_feature(feature: str, model: str):
"""
Attempt to retrieve an AutoModel class from a feature name.
Attempt to retrieve a model from a model's name and the feature to be enabled.

Args:
feature: The feature required.
feature: The feature required
model: The name of the model to export

Returns:
The AutoModel class corresponding to the feature.

"""
task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
@ -241,43 +121,38 @@ class FeaturesManager:
f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task]

def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.

Args:
feature: The feature required.
model: The name of the model to export.

Returns:
The instance of the model.

"""
model_class = FeaturesManager.get_model_class_for_feature(feature)
return model_class.from_pretrained(model)
return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)

@staticmethod
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features.
Check whether or not the model has the requested features

Args:
model: The model to export.
feature: The name of the feature to check if it is available.
model: The model to export
feature: The name of the feature to check if it is available

Returns:
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties.
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties

"""
model_type = model.config.model_type.replace("_", "-")
model_name = getattr(model, "name", "")
model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
model_name = f"({model_name})" if model_name else ""
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
raise KeyError(
f"{model.config.model_type} ({model_name}) is not supported yet. "
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
)

# Look for the features
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
if feature not in model_features:
raise ValueError(
f"{model.config.model_type} doesn't support feature {feature}. "
f"Supported values are: {model_features}"
f"Supported values are: {list(model_features.keys())}"
)

return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
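Taken together, the methods above define the resolution path an export entry point follows: normalise the model type, look up the constructor registered for the requested feature, and raise if either step fails. A minimal sketch of that flow, assuming a freshly initialised BERT model (only the check_supported_model_or_raise signature comes from this diff; everything else is illustrative):

from transformers import BertConfig, BertModel
from transformers.onnx.features import FeaturesManager

model = BertModel(BertConfig())

# Raises KeyError for an unknown model type and ValueError for an unsupported feature.
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="default")

# The factory is then called with the model config; the exact constructor it wraps
# differs between the two sides of this revert (e.g. from_model_config() vs. the class itself).
onnx_config = onnx_config_factory(model.config)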
Binary file not shown.
@ -3,8 +3,33 @@ from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch

from parameterized import parameterized
from transformers import AutoConfig, AutoTokenizer
from transformers import ( # LongformerConfig,; T5Config,
AlbertConfig,
AutoTokenizer,
BartConfig,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
LayoutLMConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
is_torch_available,
)
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig

# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.layoutlm import LayoutLMOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig

# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
@ -12,8 +37,7 @@ from transformers.onnx import (
export,
validate_model_outputs,
)
from transformers.onnx.config import OnnxConfigWithPast
from transformers.onnx.features import FeaturesManager
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow

@ -115,12 +139,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
"""

SUPPORTED_WITH_PAST_CONFIGS = {}
# SUPPORTED_WITH_PAST_CONFIGS = {
# ("BART", BartConfig),
# ("GPT2", GPT2Config),
# # ("T5", T5Config)
# }
SUPPORTED_WITH_PAST_CONFIGS = {
("BART", BartConfig),
("GPT2", GPT2Config),
# ("T5", T5Config)
}

@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
def test_use_past(self):
@ -164,37 +187,40 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
)


PYTORCH_EXPORT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
# ("longFormer", "longformer-base-4096"),
("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
}
if is_torch_available():
from transformers import ( # T5Model,
AlbertModel,
BartModel,
BertModel,
DistilBertModel,
GPT2Model,
GPTNeoModel,
LayoutLMModel,
MBartModel,
RobertaModel,
XLMRobertaModel,
)

PYTORCH_EXPORT_WITH_PAST_MODELS = {
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
PYTORCH_EXPORT_DEFAULT_MODELS = {
("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig),
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}

PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"),
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
}


def _get_models_to_test(export_models_list):
models_to_test = []
for (name, model) in export_models_list:
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
name
).items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return models_to_test
PYTORCH_EXPORT_WITH_PAST_MODELS = {
# ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
# ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
}


class OnnxExportTestCaseV2(TestCase):
@ -202,52 +228,52 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""

def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
@slow
@require_torch
def test_pytorch_export_default(self):
from transformers.onnx import export

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))

# Useful for causal lm models that do not use pad tokens.
if not getattr(config, "pad_token_id", None):
config.pad_token_id = tokenizer.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class.from_pretrained(model))
onnx_config = onnx_config_class.from_model_config(model.config)

model_class = FeaturesManager.get_model_class_for_feature(feature)
model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)
with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export(
tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
)

with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export(
tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
)
try:
validate_model_outputs(
onnx_config,
tokenizer,
model,
Path(output.name),
onnx_outputs,
onnx_config.atol_for_validation,
try:
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
except ValueError as ve:
self.fail(f"{name} -> {ve}")

@slow
@require_torch
def test_pytorch_export_with_past(self):
from transformers.onnx import export

for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")

tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class())
onnx_config = onnx_config_class.with_past(model.config)

self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
self.assertTrue(
onnx_config.use_past, "OnnxConfigWithPast.use_past should be if called with with_past()"
|
||||
)
except ValueError as ve:
self.fail(f"{name}, {feature} -> {ve}")

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
with NamedTemporaryFile("w") as output:
output = Path(output.name)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@slow
@require_torch
def test_pytorch_export_seq2seq_with_past(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
try:
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
except ValueError as ve:
self.fail(f"{name} -> {ve}")