mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
add GPT-J ONNX config to Transformers (#16274)
* add GPT-J ONNX config to Transformers * remove token-classification features mapping Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * add question-answering features mapping Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * add GPT2 config init to GPT2 config + copie shebang for fix-copies Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
aff9bc405a
commit
029b0d95ed
@ -54,6 +54,7 @@ Ready-made configurations include the following architectures:
|
||||
- ELECTRA
|
||||
- FlauBERT
|
||||
- GPT Neo
|
||||
- GPT-J
|
||||
- I-BERT
|
||||
- LayoutLM
|
||||
- M2M100
|
||||
|
@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
|
||||
"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
@ -43,7 +43,7 @@ if is_flax_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
|
||||
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_gptj import (
|
||||
|
@ -13,8 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" GPT-J model configuration"""
|
||||
from collections import OrderedDict
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -135,3 +139,84 @@ class GPTJConfig(PretrainedConfig):
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
|
||||
class GPTJOnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: str = "default",
|
||||
patching_specs: List[PatchingSpec] = None,
|
||||
use_past: bool = False,
|
||||
):
|
||||
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
|
||||
if not getattr(self._config, "pad_token_id", None):
|
||||
# TODO: how to do that better?
|
||||
self._config.pad_token_id = 0
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self._config.n_layer
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self._config.n_head
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||
tokenizer, batch_size, seq_length, is_pair, framework
|
||||
)
|
||||
|
||||
# We need to order the input in the way they appears in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
|
||||
# Need to add the past_keys
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
batch, seqlen = common_inputs["input_ids"].shape
|
||||
# Not using the same length for past_key_values
|
||||
past_key_values_length = seqlen + 2
|
||||
past_shape = (
|
||||
batch,
|
||||
self.num_attention_heads,
|
||||
past_key_values_length,
|
||||
self._config.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
ordered_inputs["past_key_values"] = [
|
||||
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
|
||||
]
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@property
|
||||
def default_onnx_opset(self) -> int:
|
||||
return 13
|
||||
|
@ -11,6 +11,7 @@ from ..models.electra import ElectraOnnxConfig
|
||||
from ..models.flaubert import FlaubertOnnxConfig
|
||||
from ..models.gpt2 import GPT2OnnxConfig
|
||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||
from ..models.gptj import GPTJOnnxConfig
|
||||
from ..models.ibert import IBertOnnxConfig
|
||||
from ..models.layoutlm import LayoutLMOnnxConfig
|
||||
from ..models.m2m_100 import M2M100OnnxConfig
|
||||
@ -233,6 +234,15 @@ class FeaturesManager:
|
||||
"token-classification",
|
||||
onnx_config_cls=GPT2OnnxConfig,
|
||||
),
|
||||
"gpt-j": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
"causal-lm",
|
||||
"causal-lm-with-past",
|
||||
"question-answering",
|
||||
"sequence-classification",
|
||||
onnx_config_cls=GPTJOnnxConfig,
|
||||
),
|
||||
"gpt-neo": supported_features_mapping(
|
||||
"default",
|
||||
"default-with-past",
|
||||
|
Loading…
Reference in New Issue
Block a user