Merge pull request #2 from erenup/run_multiple_choice_add_doc

Run multiple choice add doc
This commit is contained in:
erenup 2019-09-16 22:39:54 +08:00 committed by GitHub
commit 5a81e79e25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 198 additions and 39 deletions

View File

@ -8,7 +8,8 @@ similar API between the different models.
| [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. | | [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. | | [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. | | [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. | | [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
## Language model fine-tuning ## Language model fine-tuning
@ -282,6 +283,40 @@ The results are the following:
loss = 0.04755385363816904 loss = 0.04755385363816904
``` ```
##Multiple Choice
Based on the script [`run_multiple_choice.py`]().
#### Fine-tuning on SWAG
Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
```
#training on 4 tesla V100(16GB) GPUS
export SWAG_DIR=/path/to/swag_data_dir
python ./examples/single_model_scripts/run_multiple_choice.py \
--model_type roberta \
--task_name swag \
--model_name_or_path roberta-base \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $SWAG_DIR \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--max_seq_length 80 \
--output_dir models_bert/swag_base \
--per_gpu_eval_batch_size=16 \
--per_gpu_train_batch_size=16 \
--gradient_accumulation_steps 2 \
--overwrite_output
```
Training with the defined hyper-parameters yields the following results:
```
***** Eval results *****
eval_acc = 0.8338998300509847
eval_loss = 0.44457291918821606
```
## SQuAD ## SQuAD
Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py). Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet).""" """ Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ()) ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer), 'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix="", test=False): def evaluate(args, model, tokenizer, prefix="", test=False):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = (args.task_name,) eval_task_names = (args.task_name,)
eval_outputs_dirs = (args.output_dir,) eval_outputs_dirs = (args.output_dir,)
@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
result = {"eval_acc": acc, "eval_loss": eval_loss} result = {"eval_acc": acc, "eval_loss": eval_loss}
results.update(result) results.update(result)
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt") output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test))) logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
@ -522,9 +521,9 @@ def main():
if not args.do_train: if not args.do_train:
args.output_dir = args.model_name_or_path args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: #can not use this to do test!! just for different paras # if args.eval_all_checkpoints: # can not use this to do test!!
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) # checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging # logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """ """ BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
@ -38,11 +38,10 @@ class InputExample(object):
"""Constructs a InputExample. """Constructs a InputExample.
Args: Args:
guid: Unique id for the example. example_id: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
sequence tasks, only this sequence must be specified. question: string. The untokenized text of the second sequence (qustion).
text_b: (Optional) string. The untokenized text of the second sequence. endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples. specified for train and dev examples, but not for test examples.
""" """
@ -73,7 +72,7 @@ class InputFeatures(object):
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for multiple choice data sets."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set.""" """Gets a collection of `InputExample`s for the train set."""
@ -84,7 +83,7 @@ class DataProcessor(object):
raise NotImplementedError() raise NotImplementedError()
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set.""" """Gets a collection of `InputExample`s for the test set."""
raise NotImplementedError() raise NotImplementedError()
def get_labels(self): def get_labels(self):
@ -93,7 +92,7 @@ class DataProcessor(object):
class RaceProcessor(DataProcessor): class RaceProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the RACE data set."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
@ -152,13 +151,13 @@ class RaceProcessor(DataProcessor):
InputExample( InputExample(
example_id=race_id, example_id=race_id,
question=question, question=question,
contexts=[article, article, article, article], contexts=[article, article, article, article], # this is not efficient but convenient
endings=[options[0], options[1], options[2], options[3]], endings=[options[0], options[1], options[2], options[3]],
label=truth)) label=truth))
return examples return examples
class SwagProcessor(DataProcessor): class SwagProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the SWAG data set."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
@ -172,9 +171,12 @@ class SwagProcessor(DataProcessor):
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
"""See base class.""" """See base class."""
logger.info("LOOKING AT {} test".format(data_dir)) logger.info("LOOKING AT {} dev".format(data_dir))
raise ValueError(
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
"setting!"
)
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
return ["0", "1", "2", "3"] return ["0", "1", "2", "3"]
@ -213,7 +215,7 @@ class SwagProcessor(DataProcessor):
class ArcProcessor(DataProcessor): class ArcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the ARC data set (request from allennlp)."""
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
@ -242,6 +244,7 @@ class ArcProcessor(DataProcessor):
def _create_examples(self, lines, type): def _create_examples(self, lines, type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
#There are two types of labels. They should be normalized
def normalize(truth): def normalize(truth):
if truth in "ABCD": if truth in "ABCD":
return ord(truth) - ord("A") return ord(truth) - ord("A")
@ -256,6 +259,7 @@ class ArcProcessor(DataProcessor):
four_choice = 0 four_choice = 0
five_choice = 0 five_choice = 0
other_choices = 0 other_choices = 0
# we deleted example which has more than or less than four choices
for line in tqdm.tqdm(lines, desc="read arc data"): for line in tqdm.tqdm(lines, desc="read arc data"):
data_raw = json.loads(line.strip("\n")) data_raw = json.loads(line.strip("\n"))
if len(data_raw["question"]["choices"]) == 3: if len(data_raw["question"]["choices"]) == 3:
@ -274,7 +278,6 @@ class ArcProcessor(DataProcessor):
question = question_choices["stem"] question = question_choices["stem"]
id = data_raw["id"] id = data_raw["id"]
options = question_choices["choices"] options = question_choices["choices"]
if len(options) == 4: if len(options) == 4:
examples.append( examples.append(
InputExample( InputExample(
@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
tokens_a = tokenizer.tokenize(context) tokens_a = tokenizer.tokenize(context)
tokens_b = None tokens_b = None
if example.question.find("_") != -1: if example.question.find("_") != -1:
#this is for cloze question
tokens_b = tokenizer.tokenize(example.question.replace("_", ending)) tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
else: else:
tokens_b = tokenizer.tokenize(example.question) tokens_b = tokenizer.tokenize(example.question + " " + ending)
tokens_b += [sep_token] # you can add seq token between quesiotn and ending. This does not make too much difference.
if sep_token_extra: # tokens_b = tokenizer.tokenize(example.question)
tokens_b += [sep_token] # tokens_b += [sep_token]
tokens_b += tokenizer.tokenize(ending) # if sep_token_extra:
# tokens_b += [sep_token]
# tokens_b += tokenizer.tokenize(ending)
special_tokens_count = 4 if sep_token_extra else 3 special_tokens_count = 4 if sep_token_extra else 3
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count) _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
@ -427,15 +433,20 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
# one token at a time. This makes more sense than truncating an equal percent # one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token # of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence. # that's truncated likely contains more information than a longer sequence.
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
# length or only pop from context
while True: while True:
total_length = len(tokens_a) + len(tokens_b) total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length: if total_length <= max_length:
break break
# if len(tokens_a) > len(tokens_b): if len(tokens_a) > len(tokens_b):
# tokens_a.pop() tokens_a.pop()
# else: else:
# tokens_b.pop() logger.info('Attention! you are removing from token_b (swag task is ok). '
tokens_a.pop() 'If you are training ARC and RACE (you are poping question + options), '
'you need to try to use a bigger max seq length!')
tokens_b.pop()
processors = { processors = {

View File

@ -296,7 +296,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
Examples:: Examples::
tokenizer = RoertaTokenizer.from_pretrained('roberta-base') tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base') model = RobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
@ -338,8 +338,75 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions) return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForMultipleChoice(BertPreTrainedModel): class RobertaForMultipleChoice(BertPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
To match pre-training, RoBerta input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMultipleChoice.from_pretrained('roberta-base')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
config_class = RobertaConfig config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta" base_model_prefix = "roberta"
@ -351,7 +418,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):

View File

@ -1006,9 +1006,56 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, mems, (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class XLNetForMultipleChoice(XLNetPreTrainedModel): class XLNetForMultipleChoice(XLNetPreTrainedModel):
r""" r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to scores.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
""" """
def __init__(self, config): def __init__(self, config):
@ -1018,7 +1065,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.sequence_summary = SequenceSummary(config) self.sequence_summary = SequenceSummary(config)
self.logits_proj = nn.Linear(config.d_model, 1) self.logits_proj = nn.Linear(config.d_model, 1)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, mems=None, perm_mask=None, target_mapping=None,
@ -1105,7 +1152,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
Examples:: Examples::
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased') model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1]) start_positions = torch.tensor([1])