mirror of https://github.com/huggingface/transformers.git
fix hidden-state extraction

commit d939d6fd02
parent 0c2ff34815
@@ -855,19 +855,22 @@ class XLNetModel(XLNetPreTrainedModel):
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
-            hidden_states.append((output_h, output_g))
+            hidden_states.append((output_h, output_g) if output_g is not None else output_h)
 
             output_h, output_g = layer_module(output_h, output_g,
                                               attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                               r=pos_emb, seg_mat=seg_mat,
                                               mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)
-        hidden_states.append((output_h, output_g))
+        hidden_states.append((output_h, output_g) if output_g is not None else output_h)
         output = self.dropout(output_g if output_g is not None else output_h)
 
         # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
         output = output.permute(1, 0, 2).contiguous()
-        hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
+        if output_g is not None:
+            hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+        else:
+            hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
 
         return output, hidden_states, new_mems
 
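For context, here is a minimal standalone sketch of the bookkeeping this commit fixes (the collect_hidden_states helper and the tensor sizes are illustrative, not part of the transformers API). XLNet runs two streams per layer: the content stream output_h and the query stream output_g. output_g is only populated when a target_mapping is supplied, so the fixed code stores an (h, g) tuple per layer only in that case and a bare h tensor otherwise, then flattens the tuples before transposing every tensor back to [bsz, len, hidden_dim]:

    import torch

    seq_len, bsz, hidden_dim, n_layers = 4, 2, 8, 3

    def collect_hidden_states(output_h, output_g, n_layers):
        # Mirrors the appending/flattening logic from the diff above.
        hidden_states = []
        for _ in range(n_layers):
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)
            # ... a real layer_module would update output_h / output_g here ...
        hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        # Transpose back to [bsz, len, hidden_dim]; with the query stream
        # active, flatten the (h, g) tuples first so every entry is a tensor.
        if output_g is not None:
            return [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
        return [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]

    h = torch.randn(seq_len, bsz, hidden_dim)

    # No query stream (plain encoding): one tensor per layer plus the final output.
    print(len(collect_hidden_states(h, None, n_layers)))  # 4

    # Query stream active (permutation-LM pretraining): h and g entries, twice as many.
    g = torch.randn(seq_len, bsz, hidden_dim)
    print(len(collect_hidden_states(h, g, n_layers)))  # 8

Note that the loop appends the layer's inputs before calling layer_module, so the extra append after the loop is what captures the last layer's output; the fix simply makes both appends and the final permute agree on whether entries are tuples or tensors.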