fix BLOOM ONNX config (#19573)

* fix BLOOM ONNX config
- `value` params have `seq_len` as their 2nd axe as opposed to other models which have it as 3rd

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Nouamane Tazi 2022-10-14 15:04:48 +01:00 committed by GitHub
parent 4f0337a08f
commit 1967be98fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 4 deletions

View File

@ -152,7 +152,6 @@ class BloomConfig(PretrainedConfig):
class BloomOnnxConfig(OnnxConfigWithPast):
torch_onnx_minimum_version = version.parse("1.12")
def __init__(
@ -171,7 +170,8 @@ class BloomOnnxConfig(OnnxConfigWithPast):
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
# BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

View File

@ -486,7 +486,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
return common_inputs
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
def fill_with_past_key_values_(
self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
):
"""
Fill the input_or_outputs mapping with past_key_values dynamic axes considering.
@ -494,6 +496,8 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
inputs_or_outputs: The mapping to fill.
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
output mapping, this is important for axes naming.
inverted_values_shape:
If `True`, store values on dynamic axis 1, else on axis 2.
"""
if direction not in ["inputs", "outputs"]:
@ -502,7 +506,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
name = "past_key_values" if direction == "inputs" else "present"
for i in range(self.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
if inverted_values_shape:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
else:
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key"] = t[0]