mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix bugs
This commit is contained in:
parent
a8ad83040d
commit
5d29f8e99b
@ -274,7 +274,8 @@ class TransformerBlock(nn.Module):
|
||||
sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
|
||||
if self.output_attentions:
|
||||
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
||||
else:
|
||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||
assert type(sa_output) == tuple
|
||||
sa_output = sa_output[0]
|
||||
sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
|
||||
|
||||
@ -329,6 +330,9 @@ class Transformer(nn.Module):
|
||||
if self.output_attentions:
|
||||
attentions, hidden_state = hidden_state
|
||||
all_attentions = all_attentions + (attentions,)
|
||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||
assert type(hidden_state) == tuple
|
||||
hidden_state = hidden_state[0]
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
outputs = (hidden_state,)
|
||||
|
Loading…
Reference in New Issue
Block a user