fix last element in hidden_states for XGLM (#16301)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-03-21 17:38:52 +01:00 committed by GitHub
parent 5a42bb431e
commit 4b2774832d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -534,13 +534,18 @@ class FlaxXGLMModule(nn.Module):
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict:
outputs = (last_hidden_states,) + outputs[1:]
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)