Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

XLNET can be exported to TorchScript

commit 971c24687f (parent be54b16960)
@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
         x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
-        x = x[:, 0:klen, :, :]
+        # x = x[:, 0:klen, :, :]
+        x = torch.index_select(x, 1, torch.arange(klen))

         return x

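A quick sanity check on the rel_shift change above: torch.index_select over the first klen positions of dim 1 returns exactly what the commented-out Python slice returned, while expressing the selection as an explicit tensor op. This is an illustrative sketch with made-up shapes, not code from the diff:

    import torch

    x = torch.randn(5, 12, 2, 3)   # toy stand-in for the attention score tensor
    klen = 8

    sliced = x[:, 0:klen, :, :]                               # original slice
    gathered = torch.index_select(x, 1, torch.arange(klen))   # replacement used in the diff

    assert torch.equal(sliced, gathered)  # identical values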
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None

-        outputs = [output_h, output_g]
+        outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs = outputs + [attn_prob]
+            outputs += (attn_prob,)
         return outputs


 class XLNetFeedForward(nn.Module):
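The list-to-tuple swaps in this and the following hunks match torch.jit.trace's output rules: tuples of tensors are supported trace outputs, while list outputs are rejected outright by older PyTorch releases and only traced with a warning on newer ones. A minimal sketch with a hypothetical module, not taken from the diff:

    import torch

    class ReturnsTuple(torch.nn.Module):
        def forward(self, x):
            return (x, x + 1)   # tuple of tensors: traceable output

    class ReturnsList(torch.nn.Module):
        def forward(self, x):
            return [x, x + 1]   # list of tensors: problematic for tracing

    example = torch.randn(2, 2)
    torch.jit.trace(ReturnsTuple(), (example,))    # traces cleanly
    # torch.jit.trace(ReturnsList(), (example,))   # errors or warns, depending on the PyTorch version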
@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)

-        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs

@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
+        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen
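Both formulations of inv_freq in the hunk above compute the same values; torch.pow just spells the exponentiation as an explicit tensor op, which presumably traced more reliably than Python's ** with a plain int base on the PyTorch version targeted at the time. A small numerical check, using d_model = 768 purely as an example:

    import torch

    d_model = 768  # arbitrary choice for the check (the value XLNet-base happens to use)
    freq_seq = torch.arange(0, d_model, 2.0, dtype=torch.float)

    a = 1 / (10000 ** (freq_seq / d_model))        # original expression
    b = 1 / torch.pow(10000, freq_seq / d_model)   # replacement from the diff

    assert torch.allclose(a, b)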
@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer

-        new_mems = []
+        new_mems = ()
         if mems is None:
             mems = [None] * len(self.layer)

@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems.append(self.cache_mem(output_h, mems[i]))
+            new_mems += (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
         output = self.dropout(output_g if output_g is not None else output_h)

         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
         if self.output_hidden_states:
             if output_g is not None:
-                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
-                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-            outputs.append(hidden_states)
+                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
+            outputs += (hidden_states,)
         if self.output_attentions:
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs.append(attentions)
+            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs += (attentions,)

         return outputs  # outputs, new_mems, (hidden_states), (attentions)

@@ -986,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = self.transformer.word_embedding.weight
+        self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
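Cloning the embedding weight here trades weight tying for exportability: presumably TorchScript export does not handle a Parameter shared between the lm head and the embedding, so the head now gets an independent copy. A standalone sketch of the two variants with toy sizes, not code from the diff:

    import torch.nn as nn

    embedding = nn.Embedding(32, 8)
    lm_head = nn.Linear(8, 32, bias=False)

    lm_head.weight = embedding.weight                         # tied: one shared Parameter
    lm_head.weight = nn.Parameter(embedding.weight.clone())   # untied clone, as in the hunk above

    assert lm_head.weight is not embedding.weight  # separate storage, same initial values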
@@ -1026,14 +1027,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):

         logits = self.lm_loss(transformer_outputs[0])

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)),
                             labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1061,7 +1062,7 @@ class XLNetSequenceSummary(nn.Module):
             output = hidden_states[:, 0]
         elif self.summary_type == 'mean':
             output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
+        elif self.summary_type == 'attn':
             raise NotImplementedError

         output = self.summary(output)
@@ -1180,7 +1181,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             if self.num_labels == 1:
@@ -1190,7 +1191,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1271,7 +1272,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
@@ -1288,6 +1289,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
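Putting the commit title into practice: with list outputs turned into tuples and the lm head untied, the model becomes traceable. The sketch below shows one way to trace and reload XLNet with the present-day transformers API; the checkpoint name, output path, and the torchscript=True flag are assumptions for illustration, not part of this diff:

    import torch
    from transformers import XLNetConfig, XLNetModel, XLNetTokenizer

    # torchscript=True asks the library to clone rather than tie weights at load time
    config = XLNetConfig.from_pretrained("xlnet-base-cased", torchscript=True)
    model = XLNetModel.from_pretrained("xlnet-base-cased", config=config)
    model.eval()

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    input_ids = torch.tensor([tokenizer.encode("TorchScript export check")])

    traced = torch.jit.trace(model, (input_ids,))
    torch.jit.save(traced, "xlnet_traced.pt")

    reloaded = torch.jit.load("xlnet_traced.pt")
    outputs = reloaded(input_ids)  # tuple: hidden states first, then mems, per the hunks above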