VictorSanh 2019-08-28 00:57:16 +00:00
parent a8ad83040d
commit 5d29f8e99b

@@ -274,7 +274,8 @@ class TransformerBlock(nn.Module):
        sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
        if self.output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:
        else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
            assert type(sa_output) == tuple
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
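
The point of this hunk is that the attention sub-module now always returns a tuple, so the new else branch has to pull the tensor out of a 1-tuple before the residual connection and layer norm. A minimal, self-contained sketch of that pattern follows; ToyAttention and ToyBlock are hypothetical stand-ins, not the actual DistilBERT modules.

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # Hypothetical stand-in for self.attention: it always returns a tuple,
    # and includes the attention weights only when output_attentions is set.
    def __init__(self, dim, output_attentions=False):
        super().__init__()
        self.output_attentions = output_attentions
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)  # (bs, seq_length, seq_length)
        context = self.proj(weights @ x)                          # (bs, seq_length, dim)
        return (context, weights) if self.output_attentions else (context,)

class ToyBlock(nn.Module):
    # Mirrors the unpacking pattern from the hunk above.
    def __init__(self, dim, output_attentions=False):
        super().__init__()
        self.output_attentions = output_attentions
        self.attention = ToyAttention(dim, output_attentions)
        self.sa_layer_norm = nn.LayerNorm(dim)

    def forward(self, x):
        sa_output = self.attention(x)
        if self.output_attentions:
            sa_output, sa_weights = sa_output   # 2-tuple: (context, weights)
        else:
            assert type(sa_output) == tuple     # still a 1-tuple: (context,)
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # residual + layer norm
        return (sa_output, sa_weights) if self.output_attentions else (sa_output,)

block = ToyBlock(dim=16, output_attentions=True)
out = block(torch.randn(2, 5, 16))
print(out[0].shape, out[1].shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
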
@@ -329,6 +330,9 @@ class Transformer(nn.Module):
            if self.output_attentions:
                attentions, hidden_state = hidden_state
                all_attentions = all_attentions + (attentions,)
            else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
                assert type(hidden_state) == tuple
                hidden_state = hidden_state[0]
            all_hidden_states = all_hidden_states + (hidden_state,)
        outputs = (hidden_state,)
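
The second hunk applies the same guard inside the layer loop of the Transformer stack: each layer returns a tuple, attention weights are collected when requested, and otherwise the hidden state is unwrapped from a 1-tuple before it is fed to the next layer. Below is a rough sketch of such a loop; ToyLayer and ToyStack are hypothetical stand-ins, and gating the hidden-state collection on output_hidden_states is an assumption (the hunk only shows the accumulation line).

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Hypothetical layer: returns (attentions, hidden) or the 1-tuple (hidden,),
    # matching the shapes the loop above has to cope with.
    def __init__(self, dim, output_attentions=False):
        super().__init__()
        self.output_attentions = output_attentions
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        attentions = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
        hidden = self.linear(attentions @ x)
        return (attentions, hidden) if self.output_attentions else (hidden,)

class ToyStack(nn.Module):
    # Mirrors the accumulation loop in the second hunk.
    def __init__(self, dim, n_layers, output_attentions=False, output_hidden_states=False):
        super().__init__()
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.layer = nn.ModuleList(ToyLayer(dim, output_attentions) for _ in range(n_layers))

    def forward(self, x):
        all_hidden_states, all_attentions = (), ()
        hidden_state = x
        for layer_module in self.layer:
            hidden_state = layer_module(hidden_state)
            if self.output_attentions:
                attentions, hidden_state = hidden_state
                all_attentions = all_attentions + (attentions,)
            else:
                assert type(hidden_state) == tuple   # 1-tuple: (hidden,)
                hidden_state = hidden_state[0]
            if self.output_hidden_states:            # assumed guard, see note above
                all_hidden_states = all_hidden_states + (hidden_state,)
        outputs = (hidden_state,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs

stack = ToyStack(dim=16, n_layers=3, output_attentions=True, output_hidden_states=True)
outputs = stack(torch.randn(2, 5, 16))
print(outputs[0].shape, len(outputs[1]), len(outputs[2]))  # torch.Size([2, 5, 16]) 3 3
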