Fix the bug in constructing the all_hidden_states of DeBERTa v2 (#10466)

* fix all_hidden_states

* use output_states instead of next_kv
This commit is contained in:
felixgwu 2021-03-03 12:05:21 -05:00 committed by GitHub
parent 188574ac50
commit d064fb5647
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -450,10 +450,11 @@ class DebertaV2Encoder(nn.Module):
else:
next_kv = hidden_states
rel_embeddings = self.get_rel_embedding()
output_states = next_kv
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
all_hidden_states = all_hidden_states + (output_states,)
output_states = layer_module(
next_kv,