changing is_regression to unified API

This commit is contained in:
thomwolf 2019-06-26 09:54:05 +02:00
parent e55d4c4ede
commit 092dacfd62
4 changed files with 49 additions and 31 deletions

View File

@ -591,3 +591,15 @@ output_modes = {
"rte": "classification",
"wnli": "classification",
}
GLUE_TASKS_NUM_LABELS = {
"cola": 2,
"mnli": 3,
"mrpc": 2,
"sst-2": 2,
"sts-b": 1,
"qqp": 2,
"qnli": 2,
"rte": 2,
"wnli": 2,
}

View File

@ -28,16 +28,16 @@ from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
XLNetForSequenceClassification,
load_tf_weights_in_xlnet)
GLUE_TASKS = {
"cola": "classification",
"mnli": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
GLUE_TASKS_NUM_LABELS = {
"cola": 2,
"mnli": 3,
"mrpc": 2,
"sst-2": 2,
"sts-b": 1,
"qqp": 2,
"qnli": 2,
"rte": 2,
"wnli": 2,
}
@ -46,9 +46,9 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
config = XLNetConfig.from_json_file(bert_config_file)
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
if finetuning_task in GLUE_TASKS:
if finetuning_task in GLUE_TASKS_NUM_LABELS:
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
model = XLNetForSequenceClassification(config, is_regression=bool(GLUE_TASKS[finetuning_task] == "regression"))
model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
elif 'squad' in finetuning_task:
model = XLNetForQuestionAnswering(config)
else:

View File

@ -27,7 +27,7 @@ from io import open
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@ -1196,8 +1196,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
elif self.output_attentions:
return all_attentions, logits

View File

@ -1175,7 +1175,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None):
labels=None, output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@ -1212,11 +1212,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
logits = self.lm_loss(output)
if target is not None:
if labels is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)),
target.view(-1))
labels.view(-1))
return loss, new_mems
# if self.output_attentions:
@ -1305,13 +1305,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Outputs: Tuple of (logits or loss, mems)
`logits or loss`:
if target is None:
if labels is None:
Token logits with shape [batch_size, sequence_length]
else:
CrossEntropy loss with the targets
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Example usage:
```python
@ -1328,13 +1328,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
```
"""
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
is_regression=False, output_attentions=False, keep_multihead_output=False):
output_attentions=False, keep_multihead_output=False):
super(XLNetForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.attn_type = config.attn_type
self.same_length = config.same_length
self.summary_type = summary_type
self.is_regression = is_regression
self.num_labels = num_labels
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
@ -1342,12 +1342,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
use_proj=use_proj, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.logits_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
self.logits_proj = nn.Linear(config.d_model, num_labels)
self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None):
labels=None, output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@ -1376,19 +1376,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Set to None during finetuning.
"""
output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
output = self.sequence_summary(output)
logits = self.logits_proj(output)
if target is not None:
if self.is_regression:
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), target.view(-1))
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss, new_mems
# if self.output_attentions: