Mirror of https://github.com/huggingface/transformers.git
Added support for exporting T5 to ONNX with past_key_values (#10651)
parent 50f4539b82
commit 5c00918681
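The change caches the sequence length as a plain Python int (int_seq_length) and uses it as the bound of the negative slice on position_bias. Under torch.onnx export the model is traced, and values derived from Tensor.shape can be captured as graph tensors; slicing with such a traced value is presumably what broke the export with past_key_values, so the int() cast pins the slice bound to a constant at trace time.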
@@ -423,6 +423,8 @@ class T5Attention(nn.Module):
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
         batch_size, seq_length = hidden_states.shape[:2]
 
+        int_seq_length = int(seq_length)
+
         real_seq_length = seq_length
 
         if past_key_value is not None:
@@ -489,7 +491,7 @@ class T5Attention(nn.Module):
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -seq_length:, :]
+                position_bias = position_bias[:, :, -int_seq_length:, :]
 
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
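For context, a minimal, self-contained sketch of the same pattern; the module name BiasSlice and all shapes are hypothetical, chosen only to make the snippet runnable, and are not part of the commit:

import torch
import torch.nn as nn

class BiasSlice(nn.Module):
    # Toy stand-in for the patched lines in T5Attention.forward
    # (hypothetical module, not part of transformers).
    def forward(self, hidden_states, position_bias):
        seq_length = hidden_states.shape[1]
        # Under torch.onnx tracing, shape-derived values may be recorded
        # as tensors; int() pins the slice bound to a plain Python int.
        int_seq_length = int(seq_length)
        return position_bias[:, :, -int_seq_length:, :]

hidden_states = torch.randn(2, 5, 8)       # (batch_size, seq_length, d_model)
position_bias = torch.randn(2, 4, 12, 12)  # (batch_size, n_heads, key_length, key_length)
out = BiasSlice()(hidden_states, position_bias)  # -> shape (2, 4, 5, 12)
torch.onnx.export(BiasSlice(), (hidden_states, position_bias), "bias_slice.onnx")

The export call traces forward() once; at that point the int() call is what keeps the slice bound from being recorded as a tensor-valued index.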