From 5c00918681d6b4027701eb46cea8f795da0d4064 Mon Sep 17 00:00:00 2001
From: Kiran R
Date: Fri, 23 Apr 2021 21:44:20 +0530
Subject: [PATCH] added support for exporting of t5 to onnx with past_key_values (#10651)

---
 src/transformers/models/t5/modeling_t5.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 6d256d63f86..2779258ed07 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -423,6 +423,8 @@ class T5Attention(nn.Module):
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
         batch_size, seq_length = hidden_states.shape[:2]
 
+        int_seq_length = int(seq_length)
+
         real_seq_length = seq_length
 
         if past_key_value is not None:
@@ -489,7 +491,7 @@ class T5Attention(nn.Module):
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -seq_length:, :]
+                position_bias = position_bias[:, :, -int_seq_length:, :]
 
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)