Merge pull request #1502 from jeffxtang/master

the working example code to use BertForQuestionAnswering
This commit is contained in:
Thomas Wolf 2019-10-14 16:14:52 +02:00 committed by GitHub
commit f62f992cf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1095,12 +1095,16 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Examples:: Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased') model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
start_positions = torch.tensor([1]) input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
end_positions = torch.tensor([3]) input_ids = tokenizer.encode(input_text)
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
loss, start_scores, end_scores = outputs[:2] start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
# a nice puppet
""" """
def __init__(self, config): def __init__(self, config):