Merge pull request #2 from erenup/run_multiple_choice_add_doc
Run multiple choice add doc
commit 5a81e79e25
--- a/examples/README.md
+++ b/examples/README.md
@@ -9,6 +9,7 @@ similar API between the different models.
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
+| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
## Language model fine-tuning
@@ -282,6 +283,40 @@ The results are the following:
loss = 0.04755385363816904
```

## Multiple Choice

Based on the script [`run_multiple_choice.py`]().

#### Fine-tuning on SWAG

Download the [SWAG](https://github.com/rowanz/swagaf/tree/master/data) data.

```
# Training on 4 Tesla V100 (16GB) GPUs
export SWAG_DIR=/path/to/swag_data_dir
python ./examples/single_model_scripts/run_multiple_choice.py \
--model_type roberta \
--task_name swag \
--model_name_or_path roberta-base \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $SWAG_DIR \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--max_seq_length 80 \
--output_dir models_bert/swag_base \
--per_gpu_eval_batch_size=16 \
--per_gpu_train_batch_size=16 \
--gradient_accumulation_steps 2 \
--overwrite_output
```
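With 4 GPUs, a per-GPU batch size of 16 and 2 gradient accumulation steps, the effective training batch size is 4 × 16 × 2 = 128.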
Training with the defined hyper-parameters yields the following results:

```
***** Eval results *****
eval_acc = 0.8338998300509847
eval_loss = 0.44457291918821606
```
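After training, the checkpoint saved under `--output_dir` can be reloaded for scoring new examples. A minimal sketch, assuming `RobertaForMultipleChoice` is importable from the package root and that the script saved both model and tokenizer to that directory (the strings below are illustrative):

```
import torch
from pytorch_transformers import RobertaTokenizer, RobertaForMultipleChoice

# Load the fine-tuned checkpoint (the --output_dir used above).
tokenizer = RobertaTokenizer.from_pretrained('models_bert/swag_base')
model = RobertaForMultipleChoice.from_pretrained('models_bert/swag_base')
model.eval()

# Score one context against two candidate endings (made-up strings).
context = "The woman picks up the guitar."
endings = ["She starts to play a song.", "She eats the guitar."]
encoded = [tokenizer.encode(context + " " + e, add_special_tokens=True) for e in endings]

# Pad the choices to the same length so they stack into one tensor.
pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
max_len = max(len(ids) for ids in encoded)
encoded = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]
input_ids = torch.tensor(encoded).unsqueeze(0)  # (batch=1, num_choices=2, seq_len)

with torch.no_grad():
    classification_scores = model(input_ids)[0]  # (1, num_choices)
print(classification_scores.argmax(dim=-1))  # index of the highest-scoring ending
```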

## SQuAD

Based on the script [`run_squad.py`](https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py).

--- a/examples/single_model_scripts/run_multiple_choice.py
+++ b/examples/single_model_scripts/run_multiple_choice.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
+""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""

from __future__ import absolute_import, division, print_function

@@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)

logger = logging.getLogger(__name__)

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())

MODEL_CLASSES = {
    'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
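The hunk is cut off after the first entry. Given the new imports and the models the README names, the full mapping presumably also gains XLNet and RoBERTa entries; a sketch of the assumed final dict:

```
MODEL_CLASSES = {
    'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
    # assumed entries, following the pattern above and the README:
    'xlnet': (XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer),
    'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),
}
```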
@@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer):

def evaluate(args, model, tokenizer, prefix="", test=False):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = (args.task_name,)
    eval_outputs_dirs = (args.output_dir,)
@@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
    result = {"eval_acc": acc, "eval_loss": eval_loss}
    results.update(result)

-    output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
+    output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")

    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
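With this change, an evaluation run with the default `test=False` writes `is_test_false_eval_results.txt` rather than `is_test_False_eval_results.txt`.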
@@ -522,9 +521,9 @@ def main():
    if not args.do_train:
        args.output_dir = args.model_name_or_path
    checkpoints = [args.output_dir]
-    if args.eval_all_checkpoints:  # can not use this to do test!! just for different paras
-        checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
-        logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
+    # if args.eval_all_checkpoints:  # can not use this to do test!!
+    #     checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+    #     logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
    logger.info("Evaluate the following checkpoints: %s", checkpoints)
    for checkpoint in checkpoints:
        global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""

--- a/examples/single_model_scripts/utils_multiple_choice.py
+++ b/examples/single_model_scripts/utils_multiple_choice.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" BERT classification fine-tuning: utilities to work with GLUE tasks """
+""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """

from __future__ import absolute_import, division, print_function
@@ -38,11 +38,10 @@ class InputExample(object):
        """Constructs a InputExample.

        Args:
-            guid: Unique id for the example.
-            text_a: string. The untokenized text of the first sequence. For single
-            sequence tasks, only this sequence must be specified.
-            text_b: (Optional) string. The untokenized text of the second sequence.
-            Only must be specified for sequence pair tasks.
+            example_id: Unique id for the example.
+            contexts: list of str. The untokenized text of the first sequence (context of the corresponding question).
+            question: string. The untokenized text of the second sequence (question).
+            endings: list of str. The options of the multiple choice question. Its length must match the length of contexts.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
@@ -73,7 +72,7 @@ class InputFeatures(object):

class DataProcessor(object):
-    """Base class for data converters for sequence classification data sets."""
+    """Base class for data converters for multiple choice data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
@@ -84,7 +83,7 @@ class DataProcessor(object):
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
-        """Gets a collection of `InputExample`s for the dev set."""
+        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
@@ -93,7 +92,7 @@ class DataProcessor(object):

class RaceProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the RACE data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
@@ -152,13 +151,13 @@ class RaceProcessor(DataProcessor):
                InputExample(
                    example_id=race_id,
                    question=question,
-                    contexts=[article, article, article, article],
+                    contexts=[article, article, article, article],  # this is not efficient but convenient
                    endings=[options[0], options[1], options[2], options[3]],
                    label=truth))
        return examples

class SwagProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the SWAG data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
@@ -172,9 +171,12 @@ class SwagProcessor(DataProcessor):

    def get_test_examples(self, data_dir):
        """See base class."""
-        logger.info("LOOKING AT {} test".format(data_dir))
+        logger.info("LOOKING AT {} dev".format(data_dir))
+        raise ValueError(
+            "For swag testing, the input file does not contain a label column. It can not be tested in current code "
+            "setting!"
+        )
        return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]
@@ -213,7 +215,7 @@ class SwagProcessor(DataProcessor):

class ArcProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the ARC data set (request from allennlp)."""

    def get_train_examples(self, data_dir):
        """See base class."""
@@ -242,6 +244,7 @@ class ArcProcessor(DataProcessor):
    def _create_examples(self, lines, type):
        """Creates examples for the training and dev sets."""

+        # There are two types of labels. They should be normalized.
        def normalize(truth):
            if truth in "ABCD":
                return ord(truth) - ord("A")
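The hunk cuts off before the rest of `normalize`. For readers, a sketch of the full normalization this implies (the numeric branch and fallback are assumptions, not shown in this diff):

```
def normalize(truth):
    # Letter labels map "A"-"D" to 0-3.
    if truth in "ABCD":
        return ord(truth) - ord("A")
    # Assumed: numeric labels "1"-"4" also map to 0-3.
    if truth in "1234":
        return int(truth) - 1
    # Assumed fallback for malformed labels.
    return -1
```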
@@ -256,6 +259,7 @@ class ArcProcessor(DataProcessor):
        four_choice = 0
        five_choice = 0
        other_choices = 0
+        # we delete examples that have more or fewer than four choices
        for line in tqdm.tqdm(lines, desc="read arc data"):
            data_raw = json.loads(line.strip("\n"))
            if len(data_raw["question"]["choices"]) == 3:
@@ -274,7 +278,6 @@ class ArcProcessor(DataProcessor):
            question = question_choices["stem"]
            id = data_raw["id"]
            options = question_choices["choices"]

            if len(options) == 4:
                examples.append(
                    InputExample(
@@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
            tokens_a = tokenizer.tokenize(context)
            tokens_b = None
            if example.question.find("_") != -1:
+                # this is for cloze questions
                tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
            else:
-                tokens_b = tokenizer.tokenize(example.question)
-                tokens_b += [sep_token]
-                if sep_token_extra:
-                    tokens_b += [sep_token]
-                tokens_b += tokenizer.tokenize(ending)
+                tokens_b = tokenizer.tokenize(example.question + " " + ending)
+                # you can add a sep token between question and ending. This does not make too much difference.
+                # tokens_b = tokenizer.tokenize(example.question)
+                # tokens_b += [sep_token]
+                # if sep_token_extra:
+                #     tokens_b += [sep_token]
+                # tokens_b += tokenizer.tokenize(ending)

            special_tokens_count = 4 if sep_token_extra else 3
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
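For illustration, here is how the two question styles pair with an ending, mirroring the branch above (a sketch with made-up strings):

```
question = "The man _ the ball."           # cloze style: "_" marks the blank
ending = "kicked"
if "_" in question:
    text_b = question.replace("_", ending)   # -> "The man kicked the ball."
else:
    text_b = question + " " + ending         # plain question followed by the option
```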
@@ -427,15 +433,20 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.

+    # However, since we'd better not remove tokens from options and questions, you can choose to use a bigger
+    # max length or only pop from the context.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
+        # if len(tokens_a) > len(tokens_b):
+        #     tokens_a.pop()
+        # else:
+        #     tokens_b.pop()
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            logger.info('Attention! you are removing from tokens_b (swag task is ok). '
                        'If you are training ARC and RACE (you are popping question + options), '
                        'you need to try to use a bigger max seq length!')
            tokens_b.pop()
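A variant that trims only the context, as the new comment suggests (an editor's sketch, not part of this commit; the helper name is made up):

```
def _truncate_context_only(tokens_a, tokens_b, max_length):
    # Pop only from the context (tokens_a) so the question + options survive intact.
    while len(tokens_a) + len(tokens_b) > max_length and tokens_a:
        tokens_a.pop()
```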
processors = {

--- a/pytorch_transformers/modeling_roberta.py
+++ b/pytorch_transformers/modeling_roberta.py
@@ -296,7 +296,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):

    Examples::

-        tokenizer = RoertaTokenizer.from_pretrained('roberta-base')
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForSequenceClassification.from_pretrained('roberta-base')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
@@ -338,8 +338,75 @@ class RobertaForSequenceClassification(BertPreTrainedModel):

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of
                         the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
                      ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForMultipleChoice(BertPreTrainedModel):
    r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            To match pre-training, RoBERTa input sequences should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0    0    0    0     0     0   0     0   1  1  1   1  1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0    0   0   0   0   0   0``

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` to a `sentence B` token.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForMultipleChoice.from_pretrained('roberta-base')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]

    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"
@@ -351,7 +418,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

-        self.apply(self.init_weights)
+        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
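The hunk ends before the forward body. The usual multiple-choice pattern flattens the choice dimension before the encoder and reshapes the logits afterwards; a sketch of that pattern (an editor's illustration, not the verbatim implementation):

```
# inside forward(); input_ids has shape (batch_size, num_choices, seq_len)
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))   # (batch * choices, seq_len)
outputs = self.roberta(flat_input_ids)
pooled_output = outputs[1]
logits = self.classifier(self.dropout(pooled_output))     # (batch * choices, 1)
reshaped_logits = logits.view(-1, num_choices)            # (batch, num_choices)
```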

--- a/pytorch_transformers/modeling_xlnet.py
+++ b/pytorch_transformers/modeling_xlnet.py
@@ -1006,9 +1006,56 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):

        return outputs  # return (loss), logits, mems, (hidden states), (attentions)


@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
                         the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
                      XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class XLNetForMultipleChoice(XLNetPreTrainedModel):
    r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` to a `sentence B` token.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
        model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]

    """
    def __init__(self, config):
@@ -1018,7 +1065,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj = nn.Linear(config.d_model, 1)

-        self.apply(self.init_weights)
+        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
@@ -1105,7 +1152,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):

    Examples::

-        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
        model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])