Mirror of https://github.com/huggingface/transformers

commit 7daacf00df

    Merge pull request #1695 from huggingface/models_inputs_embeds

    model forwards can take an inputs_embeds param
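A minimal usage sketch of what the new parameter enables (the sketch is not part of the commit; the checkpoint name and input text are illustrative): a PyTorch model can now be fed pre-computed embeddings via inputs_embeds instead of token ids.

    # Sketch assuming the bert-base-uncased checkpoint; the other models touched by
    # this PR (CTRL, DistilBERT, GPT-2, OpenAI GPT, RoBERTa) follow the same pattern.
    import torch
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    input_ids = torch.tensor([tokenizer.encode("Hello, world")])

    # Usual path: the model looks the ids up in its own embedding matrix.
    outputs_from_ids = model(input_ids=input_ids)

    # New path: produce the embeddings yourself (here simply reusing the model's own
    # word embedding layer) and pass them directly; input_ids is then left as None.
    inputs_embeds = model.embeddings.word_embeddings(input_ids)
    outputs_from_embeds = model(inputs_embeds=inputs_embeds)

    # Passing both input_ids and inputs_embeds at once raises a ValueError (see the diff below).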
@@ -255,6 +255,10 @@ XXX_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
@@ -238,6 +238,10 @@ XXX_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@@ -295,7 +299,7 @@ class XxxModel(XxxPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -449,14 +453,15 @@ class XxxForSequenceClassification(XxxPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)

         pooled_output = outputs[1]

@@ -520,14 +525,15 @@ class XxxForTokenClassification(XxxPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)

         sequence_output = outputs[0]

@@ -603,14 +609,15 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):

         outputs = self.transformer(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)

         sequence_output = outputs[0]

@@ -158,19 +158,26 @@ class BertEmbeddings(nn.Module):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
-        if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-
-        words_embeddings = self.word_embeddings(input_ids)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)

-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -550,6 +557,10 @@ BERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
             is configured as a decoder.
@@ -615,8 +626,8 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.

         The model can behave as an encoder (with only self-attention) as well
@@ -632,12 +643,23 @@ class BertModel(BertPreTrainedModel):
             https://arxiv.org/abs/1706.03762

         """
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones_like(input_ids)
+            encoder_attention_mask = torch.ones(input_shape)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
@@ -649,8 +671,8 @@ class BertModel(BertPreTrainedModel):
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                batch_size, seq_length = input_shape
+                seq_ids = torch.arange(seq_length, device=device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -689,7 +711,7 @@ class BertModel(BertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers

-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
@@ -754,14 +776,15 @@ class BertForPreTraining(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, next_sentence_label=None):

         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -829,7 +852,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):

         outputs = self.bert(input_ids,
@@ -837,6 +860,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
+                            inputs_embeds=inputs_embeds,
                             encoder_hidden_states=encoder_hidden_states,
                             encoder_attention_mask=encoder_attention_mask)

@@ -908,14 +932,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 next_sentence_label=None):

         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         pooled_output = outputs[1]

@@ -975,14 +1000,15 @@ class BertForSequenceClassification(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         pooled_output = outputs[1]

@@ -1049,8 +1075,8 @@ class BertForMultipleChoice(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
         num_choices = input_ids.shape[1]

         input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -1062,7 +1088,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         pooled_output = outputs[1]

@@ -1123,14 +1150,15 @@ class BertForTokenClassification(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         sequence_output = outputs[0]

@@ -1207,14 +1235,15 @@ class BertForQuestionAnswering(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):

         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
-                            head_mask=head_mask)
+                            head_mask=head_mask,
+                            inputs_embeds=inputs_embeds)

         sequence_output = outputs[0]

@@ -236,6 +236,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
@@ -302,17 +306,26 @@ class CTRLModel(CTRLPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
         else:
             past_length = past[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

         # Attention mask.
         if attention_mask is not None:
@@ -354,9 +367,10 @@ class CTRLModel(CTRLPreTrainedModel):
             token_type_embeds = 0
         position_ids = position_ids.view(-1, input_shape[-1])

-        inputs_embeds = self.w(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
-        seq_len = input_ids.shape[-1]
+        seq_len = input_shape[-1]
         mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)

         inputs_embeds *= np.sqrt(self.d_model_size)
@@ -455,14 +469,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

         hidden_states = transformer_outputs[0]

@@ -387,6 +387,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
@@ -436,9 +440,18 @@ class DistilBertModel(DistilBertPreTrainedModel):
             self.transformer.layer[layer].attention.prune_heads(heads)

     def forward(self,
-                input_ids, attention_mask=None, head_mask=None):
+                input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)  # (bs, seq_length)
+            attention_mask = torch.ones(input_shape)  # (bs, seq_length)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -455,8 +468,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers

-        embedding_output = self.embeddings(input_ids)  # (bs, seq_length, dim)
-        tfmr_output = self.transformer(x=embedding_output,
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        tfmr_output = self.transformer(x=inputs_embeds,
                                        attn_mask=attention_mask,
                                        head_mask=head_mask)
         hidden_state = tfmr_output[0]
@@ -514,10 +528,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     def get_output_embeddings(self):
         return self.vocab_projector

-    def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
                                        attention_mask=attention_mask,
-                                       head_mask=head_mask)
+                                       head_mask=head_mask,
+                                       inputs_embeds=inputs_embeds)
         hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
@@ -578,10 +593,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            head_mask=head_mask,
+                                            inputs_embeds=inputs_embeds)
         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
         pooled_output = hidden_state[:, 0]  # (bs, dim)
         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
@@ -652,10 +668,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
+    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
                                             attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            head_mask=head_mask,
+                                            inputs_embeds=inputs_embeds)
         hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

         hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
@@ -313,6 +313,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -370,9 +374,17 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
         if position_ids is not None:
@@ -384,8 +396,9 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             past_length = past[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

         # Attention mask.
         if attention_mask is not None:
@@ -419,7 +432,8 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer

-        inputs_embeds = self.wte(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
@@ -520,14 +534,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]

         lm_logits = self.lm_head(hidden_states)
@@ -623,14 +638,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                past=past,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)

         hidden_states = transformer_outputs[0]

@@ -322,6 +322,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
@@ -373,14 +377,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if position_ids is None:
-            # This was used when we had a single embedding matrice from position and token embeddings
-            # start = self.config.vocab_size + self.config.n_special
-            # end = start + input_ids.size(-1)
-            # position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
-            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            # Code is different from when we had a single embedding matrice from position and token embeddings
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

         # Attention mask.
         if attention_mask is not None:
@@ -413,11 +425,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer

-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_ids.size(-1))
-        position_ids = position_ids.view(-1, position_ids.size(-1))
-
-        inputs_embeds = self.tokens_embed(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.tokens_embed(input_ids)
         position_embeds = self.positions_embed(position_ids)
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
@@ -495,13 +504,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)

@@ -587,13 +597,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 mc_token_ids=None, lm_labels=None, mc_labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                position_ids=position_ids,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]

         lm_logits = self.lm_head(hidden_states)
@@ -48,16 +48,24 @@ class RobertaEmbeddings(BertEmbeddings):
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size,
                                                 padding_idx=self.padding_idx)

-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
+    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if position_ids is None:
             # Position numbers begin at padding_idx+1. Padding symbols are ignored.
             # cf. fairseq's `utils.make_positions`
-            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+            position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand(input_shape)
         return super(RobertaEmbeddings, self).forward(input_ids,
                                                       token_type_ids=token_type_ids,
-                                                      position_ids=position_ids)
+                                                      position_ids=position_ids,
+                                                      inputs_embeds=inputs_embeds)


 ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
@@ -126,6 +134,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
@@ -222,13 +234,14 @@ class RobertaForMaskedLM(BertPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head.decoder

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 masked_lm_labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)

@@ -309,13 +322,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.classifier = RobertaClassificationHead(config)

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                 labels=None):
         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)

@@ -372,6 +386,10 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -415,8 +433,8 @@ class RobertaForMultipleChoice(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
+    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None, inputs_embeds=None):
         num_choices = input_ids.shape[1]

         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@@ -487,14 +505,15 @@ class RobertaForTokenClassification(BertPreTrainedModel):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
-                position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

         outputs = self.roberta(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                position_ids=position_ids,
-                               head_mask=head_mask)
+                               head_mask=head_mask,
+                               inputs_embeds=inputs_embeds)

         sequence_output = outputs[0]

@@ -616,6 +616,10 @@ BERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
@@ -374,6 +374,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
@@ -508,6 +508,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """

 @add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
@ -408,6 +408,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
@ -389,6 +389,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",
@ -157,6 +157,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
@ -626,6 +626,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
     r""" Base class for all TF models.
 
         :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
-        as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
+        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
 
         Class attributes (overridden by derived classes):
             - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
@ -530,6 +530,10 @@ XLM_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
@ -762,6 +762,10 @@ XLNET_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
@ -553,6 +553,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
@ -657,12 +661,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         logger.info("Head pruning is not implemented for Transformer-XL model")
         pass
 
-    def init_mems(self, data):
+    def init_mems(self, bsz):
         if self.mem_len > 0:
             mems = []
             param = next(self.parameters())
             for i in range(self.n_layer):
-                empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model,
+                empty = torch.zeros(self.mem_len, bsz, self.config.d_model,
                                     dtype=param.dtype, device=param.device)
                 mems.append(empty)
 
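`init_mems` now takes the batch size directly: when a caller supplies only `inputs_embeds`, there is no ids tensor whose second dimension (after the internal transpose to ``[len, bsz]``) could be read, so the forward resolves `bsz` first and passes it down. A small sketch of the resulting call site, mirroring the forward shown below:

# inside TransfoXLModel.forward, once qlen and bsz have been resolved
# from either input_ids or inputs_embeds:
if mems is None:
    mems = self.init_mems(bsz)   # one zero tensor of shape (mem_len, bsz, d_model) per layer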
@ -693,15 +697,22 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
 
         return new_mems
 
-    def forward(self, input_ids, mems=None, head_mask=None):
+    def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
         # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
         # so we transpose here from shape [bsz, len] to shape [len, bsz]
-        input_ids = input_ids.transpose(0, 1).contiguous()
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = input_ids.transpose(0, 1).contiguous()
+            qlen, bsz = input_ids.size()
+        elif inputs_embeds is not None:
+            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if mems is None:
-            mems = self.init_mems(input_ids)
-
-        qlen, bsz = input_ids.size()
+            mems = self.init_mems(bsz)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
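The Transformer-XL forward now resolves its token representation up front: exactly one of `input_ids` or `inputs_embeds` must be given, either is moved to the original ``[len, bsz]`` layout, and `qlen`/`bsz` are read from whichever is present. A hedged usage sketch against this revision (checkpoint name and token ids are illustrative):

import torch
from transformers import TransfoXLModel

model = TransfoXLModel.from_pretrained("transfo-xl-wt103")
model.eval()

input_ids = torch.tensor([[14, 29, 7]])                   # (batch_size, sequence_length)
inputs_embeds = model.get_input_embeddings()(input_ids)   # (batch_size, sequence_length, d_model)

with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)          # or model(input_ids=input_ids), never both
last_hidden_state, mems = outputs[0], outputs[1]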
@ -718,7 +729,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer
 
-        word_emb = self.word_emb(input_ids)
+        if inputs_embeds is not None:
+            word_emb = inputs_embeds
+        else:
+            word_emb = self.word_emb(input_ids)
 
         mlen = mems[0].size(0) if mems is not None else 0
         klen = mlen + qlen
@ -860,14 +874,18 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     def reset_length(self, tgt_len, ext_len, mem_len):
         self.transformer.reset_length(tgt_len, ext_len, mem_len)
 
-    def init_mems(self, data):
-        return self.transformer.init_mems(data)
+    def init_mems(self, bsz):
+        return self.transformer.init_mems(bsz)
 
-    def forward(self, input_ids, mems=None, head_mask=None, labels=None):
-        bsz = input_ids.size(0)
-        tgt_len = input_ids.size(1)
+    def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
+        if input_ids is not None:
+            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
+        elif inputs_embeds is not None:
+            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)
 
         last_hidden = transformer_outputs[0]
         pred_hid = last_hidden[:, -tgt_len:]
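The LM head repeats the pattern, deriving `bsz`/`tgt_len` from whichever input is present before delegating to the base model; `labels` stay as token ids even when the inputs are embeddings. A hedged sketch (checkpoint and ids are illustrative; the exact shape of the returned losses is left to the adaptive softmax head):

import torch
from transformers import TransfoXLLMHeadModel

model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
model.eval()

input_ids = torch.tensor([[14, 29, 7, 53]])
inputs_embeds = model.get_input_embeddings()(input_ids)

outputs = model(inputs_embeds=inputs_embeds, labels=input_ids)  # labels are still token ids
losses = outputs[0]                                             # losses from the adaptive softmax head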
@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
     r""" Base class for all models.
 
         :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
-        as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
+        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
 
         Class attributes (overridden by derived classes):
             - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
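The base-class utilities mentioned here are also the natural way to work with `inputs_embeds`: `get_input_embeddings()` (used by the new common test further down) returns the lookup table one can apply manually, while `resize_token_embeddings` and `prune_heads` cover the two responsibilities named in the docstring. A hedged sketch (model name and numbers are illustrative):

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

embeddings = model.get_input_embeddings()                      # torch.nn.Embedding of size (vocab_size, hidden_size)
model.resize_token_embeddings(embeddings.num_embeddings + 2)   # e.g. after adding two new tokens
model.prune_heads({0: [0, 1]})                                 # drop heads 0 and 1 of layer 0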
@ -311,6 +311,10 @@ XLM_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
@ -421,14 +425,21 @@ class XLMModel(XLMPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.attentions[layer].prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None):  # removed: src_enc=None, src_len=None
+        if input_ids is not None:
+            bs, slen = input_ids.size()
+        else:
+            bs, slen = inputs_embeds.size()[:-1]
+
         if lengths is None:
-            lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            if input_ids is not None:
+                lengths = (input_ids != self.pad_index).sum(dim=1).long()
+            else:
+                lengths = torch.LongTensor([slen]*bs)
         # mask = input_ids != self.pad_index
 
         # check inputs
-        bs, slen = input_ids.size()
         assert lengths.size(0) == bs
         assert lengths.max().item() <= slen
         # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
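Note the fallback for XLM: `lengths` used to be inferred from the padding index of `input_ids`; with embeddings only, they default to the full sequence length for every example, so padded batches should pass `lengths` (or an `attention_mask`) explicitly. A hedged sketch (checkpoint, token ids and the padding assumption are illustrative):

import torch
from transformers import XLMModel

model = XLMModel.from_pretrained("xlm-mlm-en-2048")
model.eval()

input_ids = torch.tensor([[42, 13, 5, 2]])                # last position is padding here (illustrative)
inputs_embeds = model.get_input_embeddings()(input_ids)   # (1, 4, embedding_dim)

with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds,
                    lengths=torch.tensor([3]))            # real length per example, since it cannot be inferred
last_hidden_state = outputs[0]                            # (1, 4, embedding_dim)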
@ -442,10 +453,12 @@ class XLMModel(XLMPreTrainedModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
 
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         # position_ids
         if position_ids is None:
-            position_ids = input_ids.new((slen,)).long()
-            position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
+            position_ids = torch.arange(slen, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
         else:
             assert position_ids.size() == (bs, slen)  # (slen, bs)
             # position_ids = position_ids.transpose(0, 1)
@ -471,7 +484,7 @@ class XLMModel(XLMPreTrainedModel):
             head_mask = [None] * self.n_layers
 
         # do not recompute cached elements
-        if cache is not None:
+        if cache is not None and input_ids is not None:
             _slen = slen - cache['slen']
             input_ids = input_ids[:, -_slen:]
             position_ids = position_ids[:, -_slen:]
@ -481,8 +494,10 @@ class XLMModel(XLMPreTrainedModel):
             attn_mask = attn_mask[:, -_slen:]
 
         # embeddings
-        tensor = self.embeddings(input_ids)
-        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
         if langs is not None and self.use_lang_emb:
             tensor = tensor + self.lang_embeddings(langs)
         if token_type_ids is not None:
@ -624,8 +639,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                langs=langs,
@ -633,7 +648,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
 
         output = transformer_outputs[0]
         outputs = self.pred_layer(output, labels)
@ -685,8 +701,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
-                lengths=None, cache=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
+                lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                langs=langs,
@ -694,7 +710,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                                                position_ids=position_ids,
                                                lengths=lengths,
                                                cache=cache,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
 
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||||
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
|
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
|
||||||
transformer_outputs = self.transformer(input_ids,
|
transformer_outputs = self.transformer(input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
langs=langs,
|
langs=langs,
|
||||||
@ -777,7 +794,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
sequence_output = transformer_outputs[0]
|
sequence_output = transformer_outputs[0]
|
||||||
|
|
||||||
@ -863,8 +881,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||||
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
|
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
|
||||||
is_impossible=None, cls_index=None, p_mask=None):
|
is_impossible=None, cls_index=None, p_mask=None):
|
||||||
transformer_outputs = self.transformer(input_ids,
|
transformer_outputs = self.transformer(input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@ -873,7 +891,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
output = transformer_outputs[0]
|
output = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -558,6 +558,10 @@ XLNET_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
 """
 
 @add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
@ -712,19 +716,29 @@ class XLNetModel(XLNetPreTrainedModel):
         pos_emb = pos_emb.to(next(self.parameters()))
         return pos_emb
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
         # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
         # but we want a unified interface in the library with the batch size on the first dimension
         # so we move here the first dimension (batch) to the end
-        input_ids = input_ids.transpose(0, 1).contiguous()
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = input_ids.transpose(0, 1).contiguous()
+            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
+        elif inputs_embeds is not None:
+            inputs_embeds.transpose(0, 1).contiguous()
+            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
         input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
         attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
         perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
         target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
 
-        qlen, bsz = input_ids.shape[0], input_ids.shape[1]
         mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen
 
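XLNet applies the same resolution, with the extra twist that the batch dimension is moved to the end first. Factored out as a standalone sketch (purely illustrative, not a helper that exists in the library), the shared pattern used by these forwards is:

def resolve_batch_and_length(input_ids=None, inputs_embeds=None):
    """Return (batch_size, sequence_length) from exactly one of the two inputs."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is not None:
        return input_ids.shape[0], input_ids.shape[1]
    if inputs_embeds is not None:
        return inputs_embeds.shape[0], inputs_embeds.shape[1]   # embeddings carry a trailing hidden dim
    raise ValueError("You have to specify either input_ids or inputs_embeds")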
@ -777,7 +791,10 @@ class XLNetModel(XLNetPreTrainedModel):
             non_tgt_mask = None
 
         ##### Word embeddings and prepare h & g hidden states
-        word_emb_k = self.word_embedding(input_ids)
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
+        else:
+            word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k)
         if target_mapping is not None:
             word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
@ -924,8 +941,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                mems=mems,
@ -933,7 +950,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
 
         logits = self.lm_loss(transformer_outputs[0])
 
@ -998,8 +1016,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None, labels=None):
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
                                                mems=mems,
@ -1007,7 +1025,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         output = transformer_outputs[0]
 
         output = self.sequence_summary(output)
@ -1049,6 +1068,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@ -1093,9 +1116,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
+    def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
-                labels=None, head_mask=None):
+                labels=None, head_mask=None, inputs_embeds=None):
         num_choices = input_ids.shape[1]
 
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@ -1106,7 +1129,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
                                                input_mask=flat_input_mask, attention_mask=flat_attention_mask,
                                                mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask, inputs_embeds=inputs_embeds)
 
         output = transformer_outputs[0]
@ -1178,8 +1201,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None):
 
         outputs = self.transformer(input_ids,
@ -1189,7 +1212,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
                                    target_mapping=target_mapping,
                                    token_type_ids=token_type_ids,
                                    input_mask=input_mask,
-                                   head_mask=head_mask)
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
 
         sequence_output = outputs[0]
 
@ -1294,8 +1318,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
 
         self.init_weights()
 
-    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-                token_type_ids=None, input_mask=None, head_mask=None,
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
                 start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
         transformer_outputs = self.transformer(input_ids,
                                                attention_mask=attention_mask,
@ -1304,7 +1328,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                                                target_mapping=target_mapping,
                                                token_type_ids=token_type_ids,
                                                input_mask=input_mask,
-                                               head_mask=head_mask)
+                                               head_mask=head_mask,
+                                               inputs_embeds=inputs_embeds)
         hidden_states = transformer_outputs[0]
         start_logits = self.start_logits(hidden_states, p_mask=p_mask)
 
@ -525,6 +525,19 @@ class CommonTestCases:
             # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
             # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+        def test_inputs_embeds(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            input_ids = inputs_dict["input_ids"]
+            del inputs_dict["input_ids"]
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model.eval()
+
+                wte = model.get_input_embeddings()
+                inputs_dict["inputs_embeds"] = wte(input_ids)
+                outputs = model(**inputs_dict)
+
 
     class GPTModelTester(CommonModelTester):
 
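The new common test only asserts that a forward pass from embeddings runs for every model class. A hedged, hand-run variant of the same idea, checking that feeding a model its own embedding lookup reproduces the ids-based output (checkpoint and token ids are illustrative; exact agreement assumes the embedding path adds the same positional and segment information in both cases, as it does for BERT):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])

with torch.no_grad():
    from_ids = model(input_ids)[0]
    from_embeds = model(inputs_embeds=model.get_input_embeddings()(input_ids))[0]

print(torch.allclose(from_ids, from_embeds, atol=1e-5))   # expected to print True at this revision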