add GPT-J ONNX config to Transformers (#16274)

* add GPT-J ONNX config to Transformers

* remove token-classification features mapping

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add question-answering features mapping

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add GPT2 ONNX config init to GPT-J config + copy shebang for fix-copies

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Thomas Chaigneau 2022-03-23 21:36:11 +01:00 committed by GitHub
parent aff9bc405a
commit 029b0d95ed
4 changed files with 98 additions and 2 deletions


@@ -54,6 +54,7 @@ Ready-made configurations include the following architectures:
- ELECTRA
- FlauBERT
- GPT Neo
- GPT-J
- I-BERT
- LayoutLM
- M2M100
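
For reference, exporting a GPT-J checkpoint with this ready-made configuration follows the same pattern as the other architectures. A minimal sketch, assuming the transformers.onnx Python export API of this release and a GPT-J checkpoint you can load locally (illustrative, not part of the diff):

# Illustrative sketch only; requires `pip install transformers[onnx]` and enough
# memory to load the checkpoint.
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.gptj import GPTJOnnxConfig
from transformers.onnx import export

checkpoint = "EleutherAI/gpt-j-6B"  # assumption: any GPT-J checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Build the ready-made ONNX config from the loaded model config and export.
onnx_config = GPTJOnnxConfig(model.config, task="default")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("gptj.onnx")
)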


@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available

_import_structure = {
-    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
+    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
}

if is_torch_available():
@@ -43,7 +43,7 @@ if is_flax_available():

if TYPE_CHECKING:
-    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
+    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig

    if is_torch_available():
        from .modeling_gptj import (

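With the lazy import structure updated above, the new class is importable from the model subpackage. A quick sketch of what that enables (illustrative, not part of the diff):

# Illustrative sketch only; GPTJConfig() below uses the default GPT-J hyperparameters.
from transformers.models.gptj import GPTJConfig, GPTJOnnxConfig

onnx_config = GPTJOnnxConfig(GPTJConfig(), task="causal-lm")
print(onnx_config.inputs)
# -> OrderedDict with dynamic axes for "input_ids" and "attention_mask"
#    ({0: "batch", 1: "sequence"} each, since use_past defaults to False)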

@@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" GPT-J model configuration"""
from collections import OrderedDict
from typing import Any, List, Mapping, Optional
from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging
@@ -135,3 +139,84 @@ class GPTJConfig(PretrainedConfig):
        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )

# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
class GPTJOnnxConfig(OnnxConfigWithPast):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # We need to order the inputs in the way they appear in forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                past_shape = (
                    batch,
                    self.num_attention_heads,
                    past_key_values_length,
                    self._config.hidden_size // self.num_attention_heads,
                )
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
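
The with-past branch above allocates zeroed key/value tensors per layer and widens the attention mask accordingly. A small sketch of what generate_dummy_inputs returns on that path, assuming PyTorch is installed and using the GPT-J tokenizer checkpoint as a stand-in (illustrative, not part of the diff):

# Illustrative sketch only; the tokenizer checkpoint choice is an assumption.
from transformers import AutoTokenizer, TensorType
from transformers.models.gptj import GPTJConfig, GPTJOnnxConfig

config = GPTJConfig()
onnx_config = GPTJOnnxConfig(config, task="causal-lm", use_past=True)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print(list(dummy))  # ['input_ids', 'past_key_values', 'attention_mask']

# Each of the num_layers past entries is a (key, value) pair with the past_shape
# computed above: (batch, n_head, seq_length + 2, hidden_size // n_head).
print(dummy["past_key_values"][0][0].shape)

# The attention mask is concatenated with ones for the past positions, so its
# second dimension grows by past_key_values_length.
print(dummy["attention_mask"].shape)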


@@ -11,6 +11,7 @@ from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.gptj import GPTJOnnxConfig
from ..models.ibert import IBertOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.m2m_100 import M2M100OnnxConfig
@@ -233,6 +234,15 @@ class FeaturesManager:
            "token-classification",
            onnx_config_cls=GPT2OnnxConfig,
        ),
        "gpt-j": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "question-answering",
            "sequence-classification",
            onnx_config_cls=GPTJOnnxConfig,
        ),
        "gpt-neo": supported_features_mapping(
            "default",
            "default-with-past",