Fix hidden-state extraction

This commit is contained in:
thomwolf 2019-06-27 09:39:44 +02:00
parent 0c2ff34815
commit d939d6fd02

View File

@ -855,19 +855,22 @@ class XLNetModel(XLNetPreTrainedModel):
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems.append(self.cache_mem(output_h, mems[i]))
hidden_states.append((output_h, output_g))
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output_h, output_g = layer_module(output_h, output_g,
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat,
mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask)
hidden_states.append((output_h, output_g))
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h)
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = output.permute(1, 0, 2).contiguous()
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
if output_g is not None:
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
else:
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
return output, hidden_states, new_mems