Merge pull request #1695 from huggingface/models_inputs_embeds

model forwards can take an inputs_embeds param
This commit is contained in:
Julien Chaumond 2019-11-05 09:55:28 -05:00 committed by GitHub
commit 7daacf00df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 387 additions and 158 deletions

View File

@ -255,6 +255,10 @@ XXX_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -238,6 +238,10 @@ XXX_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
@ -295,7 +299,7 @@ class XxxModel(XxxPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@ -449,14 +453,15 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
pooled_output = outputs[1]
@ -520,14 +525,15 @@ class XxxForTokenClassification(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
@ -603,14 +609,15 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]

View File

@ -158,19 +158,26 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
words_embeddings = self.word_embeddings(input_ids)
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@ -550,6 +557,10 @@ BERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
is configured as a decoder.
@ -615,8 +626,8 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
""" Forward pass on the Model.
The model can behave as an encoder (with only self-attention) as well
@ -632,12 +643,23 @@ class BertModel(BertPreTrainedModel):
https://arxiv.org/abs/1706.03762
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.ones(input_shape)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones_like(input_ids)
encoder_attention_mask = torch.ones(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
token_type_ids = torch.zeros(input_shape, dtype=torch.long)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
@ -649,8 +671,8 @@ class BertModel(BertPreTrainedModel):
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length, device=input_ids.device)
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
@ -689,7 +711,7 @@ class BertModel(BertPreTrainedModel):
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
@ -754,14 +776,15 @@ class BertForPreTraining(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, next_sentence_label=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@ -829,7 +852,7 @@ class BertForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
outputs = self.bert(input_ids,
@ -837,6 +860,7 @@ class BertForMaskedLM(BertPreTrainedModel):
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)
@ -908,14 +932,15 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
next_sentence_label=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
pooled_output = outputs[1]
@ -975,14 +1000,15 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
pooled_output = outputs[1]
@ -1049,8 +1075,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
num_choices = input_ids.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1))
@ -1062,7 +1088,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
pooled_output = outputs[1]
@ -1123,14 +1150,15 @@ class BertForTokenClassification(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
@ -1207,14 +1235,15 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]

View File

@ -236,6 +236,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
@ -302,17 +306,26 @@ class CTRLModel(CTRLPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
@ -354,9 +367,10 @@ class CTRLModel(CTRLPreTrainedModel):
token_type_embeds = 0
position_ids = position_ids.view(-1, input_shape[-1])
inputs_embeds = self.w(input_ids)
if inputs_embeds is None:
inputs_embeds = self.w(input_ids)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_ids.shape[-1]
seq_len = input_shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
inputs_embeds *= np.sqrt(self.d_model_size)
@ -455,14 +469,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
labels=None):
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]

View File

@ -387,6 +387,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
@ -436,9 +440,18 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.transformer.layer[layer].attention.prune_heads(heads)
def forward(self,
input_ids, attention_mask=None, head_mask=None):
input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
attention_mask = torch.ones(input_shape) # (bs, seq_length)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -455,8 +468,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=embedding_output,
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=inputs_embeds,
attn_mask=attention_mask,
head_mask=head_mask)
hidden_state = tfmr_output[0]
@ -514,10 +528,11 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def get_output_embeddings(self):
return self.vocab_projector
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
dlbrt_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
@ -578,10 +593,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
@ -652,10 +668,11 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)

View File

@ -313,6 +313,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@ -370,9 +374,17 @@ class GPT2Model(GPT2PreTrainedModel):
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
@ -384,8 +396,9 @@ class GPT2Model(GPT2PreTrainedModel):
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
@ -419,7 +432,8 @@ class GPT2Model(GPT2PreTrainedModel):
else:
head_mask = [None] * self.config.n_layer
inputs_embeds = self.wte(input_ids)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
@ -520,14 +534,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
labels=None):
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@ -623,14 +638,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
mc_token_ids=None, lm_labels=None, mc_labels=None):
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]

View File

@ -322,6 +322,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
@ -373,14 +377,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None:
# This was used when we had a single embedding matrice from position and token embeddings
# start = self.config.vocab_size + self.config.n_special
# end = start + input_ids.size(-1)
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Code is different from when we had a single embedding matrice from position and token embeddings
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
@ -413,11 +425,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
else:
head_mask = [None] * self.config.n_layer
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.tokens_embed(input_ids)
if inputs_embeds is None:
inputs_embeds = self.tokens_embed(input_ids)
position_embeds = self.positions_embed(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
@ -495,13 +504,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@ -587,13 +597,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
mc_token_ids=None, lm_labels=None, mc_labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)

View File

@ -48,16 +48,24 @@ class RobertaEmbeddings(BertEmbeddings):
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size,
padding_idx=self.padding_idx)
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
# cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
return super(RobertaEmbeddings, self).forward(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids)
position_ids=position_ids,
inputs_embeds=inputs_embeds)
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
@ -126,6 +134,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
@ -222,13 +234,14 @@ class RobertaForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
masked_lm_labels=None):
outputs = self.roberta(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
@ -309,13 +322,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
labels=None):
outputs = self.roberta(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
@ -372,6 +386,10 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@ -415,8 +433,8 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None, inputs_embeds=None):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@ -487,14 +505,15 @@ class RobertaForTokenClassification(BertPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.roberta(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]

View File

@ -616,6 +616,10 @@ BERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -374,6 +374,10 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",

View File

@ -508,6 +508,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",

View File

@ -408,6 +408,10 @@ GPT2_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -389,6 +389,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare OpenAI GPT transformer model outputing raw hidden-states without any specific head on top.",

View File

@ -157,6 +157,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -626,6 +626,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.

View File

@ -530,6 +530,10 @@ XLM_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -762,6 +762,10 @@ XLNET_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",

View File

@ -553,6 +553,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
@ -657,12 +661,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
logger.info("Head pruning is not implemented for Transformer-XL model")
pass
def init_mems(self, data):
def init_mems(self, bsz):
if self.mem_len > 0:
mems = []
param = next(self.parameters())
for i in range(self.n_layer):
empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model,
empty = torch.zeros(self.mem_len, bsz, self.config.d_model,
dtype=param.dtype, device=param.device)
mems.append(empty)
@ -693,15 +697,22 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems
def forward(self, input_ids, mems=None, head_mask=None):
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids = input_ids.transpose(0, 1).contiguous()
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = input_ids.transpose(0, 1).contiguous()
qlen, bsz = input_ids.size()
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if mems is None:
mems = self.init_mems(input_ids)
qlen, bsz = input_ids.size()
mems = self.init_mems(bsz)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -718,7 +729,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else:
head_mask = [None] * self.n_layer
word_emb = self.word_emb(input_ids)
if inputs_embeds is not None:
word_emb = inputs_embeds
else:
word_emb = self.word_emb(input_ids)
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
@ -860,14 +874,18 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len)
def init_mems(self, data):
return self.transformer.init_mems(data)
def init_mems(self, bsz):
return self.transformer.init_mems(bsz)
def forward(self, input_ids, mems=None, head_mask=None, labels=None):
bsz = input_ids.size(0)
tgt_len = input_ids.size(1)
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
if input_ids is not None:
bsz, tgt_len = input_ids.size(0), input_ids.size(1)
elif inputs_embeds is not None:
bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)
transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]

View File

@ -53,7 +53,7 @@ class PreTrainedModel(nn.Module):
r""" Base class for all models.
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.

View File

@ -311,6 +311,10 @@ XLM_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
@ -421,14 +425,21 @@ class XLMModel(XLMPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.attentions[layer].prune_heads(heads)
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None): # removed: src_enc=None, src_len=None
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None): # removed: src_enc=None, src_len=None
if input_ids is not None:
bs, slen = input_ids.size()
else:
bs, slen = inputs_embeds.size()[:-1]
if lengths is None:
lengths = (input_ids != self.pad_index).sum(dim=1).long()
if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long()
else:
lengths = torch.LongTensor([slen]*bs)
# mask = input_ids != self.pad_index
# check inputs
bs, slen = input_ids.size()
assert lengths.size(0) == bs
assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
@ -442,10 +453,12 @@ class XLMModel(XLMPreTrainedModel):
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids
if position_ids is None:
position_ids = input_ids.new((slen,)).long()
position_ids = torch.arange(slen, out=position_ids).unsqueeze(0)
position_ids = torch.arange(slen, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
else:
assert position_ids.size() == (bs, slen) # (slen, bs)
# position_ids = position_ids.transpose(0, 1)
@ -471,7 +484,7 @@ class XLMModel(XLMPreTrainedModel):
head_mask = [None] * self.n_layers
# do not recompute cached elements
if cache is not None:
if cache is not None and input_ids is not None:
_slen = slen - cache['slen']
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
@ -481,8 +494,10 @@ class XLMModel(XLMPreTrainedModel):
attn_mask = attn_mask[:, -_slen:]
# embeddings
tensor = self.embeddings(input_ids)
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
@ -624,8 +639,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@ -633,7 +648,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
outputs = self.pred_layer(output, labels)
@ -685,8 +701,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@ -694,7 +710,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
@ -768,8 +785,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
langs=langs,
@ -777,7 +794,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = transformer_outputs[0]
@ -863,8 +881,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
lengths=None, cache=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None,
is_impossible=None, cls_index=None, p_mask=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
@ -873,7 +891,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]

View File

@ -558,6 +558,10 @@ XLNET_INPUTS_DOCSTRING = r"""
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
@ -712,19 +716,29 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None):
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None):
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
input_ids = input_ids.transpose(0, 1).contiguous()
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = input_ids.transpose(0, 1).contiguous()
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
elif inputs_embeds is not None:
inputs_embeds.transpose(0, 1).contiguous()
qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
@ -777,7 +791,10 @@ class XLNetModel(XLNetPreTrainedModel):
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(input_ids)
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k)
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
@ -924,8 +941,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
@ -933,7 +950,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
logits = self.lm_loss(transformer_outputs[0])
@ -998,8 +1016,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, labels=None):
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
@ -1007,7 +1025,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
output = self.sequence_summary(output)
@ -1049,6 +1068,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@ -1093,9 +1116,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
def forward(self, input_ids=None, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None,
labels=None, head_mask=None):
labels=None, head_mask=None, inputs_embeds=None):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
@ -1106,7 +1129,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
input_mask=flat_input_mask, attention_mask=flat_attention_mask,
mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
head_mask=head_mask)
head_mask=head_mask, inputs_embeds=inputs_embeds)
output = transformer_outputs[0]
@ -1178,8 +1201,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None):
outputs = self.transformer(input_ids,
@ -1189,7 +1212,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
@ -1294,8 +1318,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.init_weights()
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None,
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None,
start_positions=None, end_positions=None, is_impossible=None, cls_index=None, p_mask=None,):
transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
@ -1304,7 +1328,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask)
head_mask=head_mask,
inputs_embeds=inputs_embeds)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)

View File

@ -525,6 +525,19 @@ class CommonTestCases:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
for model_class in self.all_model_classes:
model = model_class(config)
model.eval()
wte = model.get_input_embeddings()
inputs_dict["inputs_embeds"] = wte(input_ids)
outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester):