mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
mc_token_mask => mc_token_ids
This commit is contained in:
parent
f4a07a392c
commit
1320e4ec0c
@ -64,7 +64,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
||||
for dataset in encoded_datasets:
|
||||
n_batch = len(dataset)
|
||||
input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
|
||||
mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
|
||||
mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
|
||||
lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
|
||||
mc_labels = np.zeros((n_batch,), dtype=np.int64)
|
||||
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
|
||||
@ -72,12 +72,12 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
||||
with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
|
||||
input_ids[i, 0, :len(with_cont1)] = with_cont1
|
||||
input_ids[i, 1, :len(with_cont2)] = with_cont2
|
||||
mc_token_mask[i, 0, len(with_cont1) - 1] = 1
|
||||
mc_token_mask[i, 1, len(with_cont2) - 1] = 1
|
||||
mc_token_ids[i, 0] = len(with_cont1) - 1
|
||||
mc_token_ids[i, 1] = len(with_cont2) - 1
|
||||
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
|
||||
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
|
||||
mc_labels[i] = mc_label
|
||||
all_inputs = (input_ids, mc_token_mask, lm_labels, mc_labels)
|
||||
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
|
||||
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
|
||||
return tensor_datasets
|
||||
|
||||
@ -197,8 +197,8 @@ def main():
|
||||
tqdm_bar = tqdm(train_dataloader, desc="Training")
|
||||
for step, batch in enumerate(tqdm_bar):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, mc_token_mask, lm_labels, mc_labels = batch
|
||||
losses = model(input_ids, mc_token_mask, lm_labels, mc_labels)
|
||||
input_ids, mc_token_ids, lm_labels, mc_labels = batch
|
||||
losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
|
||||
loss = args.lm_coef * losses[0] + losses[1]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@ -226,10 +226,10 @@ def main():
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, mc_token_mask, lm_labels, mc_labels = batch
|
||||
input_ids, mc_token_ids, lm_labels, mc_labels = batch
|
||||
with torch.no_grad():
|
||||
_, mc_loss = model(input_ids, mc_token_mask, lm_labels, mc_labels)
|
||||
_, mc_logits = model(input_ids, mc_token_mask)
|
||||
_, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels)
|
||||
_, mc_logits = model(input_ids, mc_token_ids)
|
||||
|
||||
mc_logits = mc_logits.detach().cpu().numpy()
|
||||
mc_labels = mc_labels.to('cpu').numpy()
|
||||
|
@ -366,23 +366,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
||||
nn.init.normal_(self.linear.weight, std=0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, hidden_states, mc_token_mask):
|
||||
def forward(self, hidden_states, mc_token_ids):
|
||||
# Classification logits
|
||||
# hidden_states = hidden_states.view(-1, self.n_embd)
|
||||
# mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||
mc_token_mask = mc_token_mask.float()
|
||||
multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1)
|
||||
multiple_choice_h = multiple_choice_h.sum(dim=-2)
|
||||
# flat = x[..., 0].contiguous().view(-1)
|
||||
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
|
||||
# multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1)
|
||||
# # This double transposition is there to replicate the behavior
|
||||
# # of the noise_shape argument in the tensorflow
|
||||
# # implementation. For more details, see
|
||||
# # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
|
||||
# multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
|
||||
# multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd)
|
||||
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
||||
# mc_token_ids (bsz, num_choices)
|
||||
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
||||
# (bsz, num_choices, 1, hidden_size)
|
||||
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
||||
# (bsz, num_choices, hidden_size)
|
||||
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
|
||||
# (bsz, num_choices)
|
||||
return multiple_choice_logits
|
||||
|
||||
|
||||
@ -727,7 +720,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
|
||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
|
||||
"""OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").
|
||||
|
||||
OpenAI GPT use a single embedding matrix to store the word and special embeddings.
|
||||
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
|
||||
@ -750,8 +743,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
config: a OpenAIGPTConfig class instance with the configuration to build a new model
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
|
||||
indices selected in the range [0, total_tokens_embeddings[
|
||||
`mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
|
||||
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
@ -775,13 +770,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
|
||||
mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice)
|
||||
|
||||
config = modeling_openai.OpenAIGPTConfig()
|
||||
|
||||
model = modeling_openai.OpenAIGPTLMHeadModel(config)
|
||||
lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
|
||||
lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
|
||||
```
|
||||
"""
|
||||
|
||||
@ -799,10 +794,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||
|
||||
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||
losses = []
|
||||
if lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
|
@ -89,11 +89,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
mc_labels = None
|
||||
lm_labels = None
|
||||
mc_token_mask = None
|
||||
mc_token_ids = None
|
||||
if self.use_labels:
|
||||
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
||||
mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
|
||||
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()
|
||||
|
||||
config = OpenAIGPTConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
@ -109,10 +109,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return (config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_mask)
|
||||
mc_labels, lm_labels, mc_token_ids)
|
||||
|
||||
def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_mask):
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = OpenAIGPTModel(config)
|
||||
model.eval()
|
||||
hidden_states = model(input_ids, position_ids, token_type_ids)
|
||||
@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
|
||||
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_mask):
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = OpenAIGPTLMHeadModel(config)
|
||||
model.eval()
|
||||
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
|
||||
@ -151,13 +151,13 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
[])
|
||||
|
||||
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_mask):
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = OpenAIGPTDoubleHeadsModel(config)
|
||||
model.eval()
|
||||
loss = model(input_ids, mc_token_mask,
|
||||
loss = model(input_ids, mc_token_ids,
|
||||
lm_labels=lm_labels, mc_labels=mc_labels,
|
||||
token_type_ids=token_type_ids, position_ids=position_ids)
|
||||
lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
lm_logits, mc_logits = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
outputs = {
|
||||
"loss": loss,
|
||||
"lm_logits": lm_logits,
|
||||
|
Loading…
Reference in New Issue
Block a user