mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fix BLOOM ONNX config (#19573)
* fix BLOOM ONNX config - `value` params have `seq_len` as their 2nd axis, as opposed to other models, which have it as their 3rd Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
4f0337a08f
commit
1967be98fa
@ -152,7 +152,6 @@ class BloomConfig(PretrainedConfig):
|
||||
|
||||
|
||||
class BloomOnnxConfig(OnnxConfigWithPast):
|
||||
|
||||
torch_onnx_minimum_version = version.parse("1.12")
|
||||
|
||||
def __init__(
|
||||
@ -171,7 +170,8 @@ class BloomOnnxConfig(OnnxConfigWithPast):
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||
# BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
|
||||
self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
else:
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
@ -486,7 +486,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
|
||||
return common_inputs
|
||||
|
||||
def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
|
||||
def fill_with_past_key_values_(
|
||||
self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
|
||||
):
|
||||
"""
|
||||
Fill the input_or_outputs mapping with past_key_values dynamic axes considering.
|
||||
|
||||
@ -494,6 +496,8 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
inputs_or_outputs: The mapping to fill.
|
||||
direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
|
||||
output mapping, this is important for axes naming.
|
||||
inverted_values_shape:
|
||||
If `True`, store values on dynamic axis 1, else on axis 2.
|
||||
|
||||
"""
|
||||
if direction not in ["inputs", "outputs"]:
|
||||
@ -502,7 +506,10 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
name = "past_key_values" if direction == "inputs" else "present"
|
||||
for i in range(self.num_layers):
|
||||
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
if inverted_values_shape:
|
||||
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
|
||||
else:
|
||||
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
||||
|
||||
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
||||
flattened_output[f"{name}.{idx}.key"] = t[0]
|
||||
|
Loading…
Reference in New Issue
Block a user