diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index fecf1e4de87..8c92241fa22 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -1095,12 +1095,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     Examples::
 
         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-        model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
-        start_positions = torch.tensor([1])
-        end_positions = torch.tensor([3])
-        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
-        loss, start_scores, end_scores = outputs[:2]
+        model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
+        input_ids = tokenizer.encode(input_text)
+        token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
+        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
+        # a nice puppet
+
     """
 
     def __init__(self, config):