Remove token_type_ids for compatibility with DistilBert

This commit is contained in:
Morgan Funtowicz 2019-12-09 18:34:58 +01:00
parent fe0f552e00
commit a7d3794a29

View File

@@ -20,7 +20,7 @@ from typing import Union, Optional, Tuple, List, Dict
import numpy as np
-from transformers import is_tf_available, logger, AutoTokenizer, PreTrainedTokenizer, is_torch_available
+from transformers import is_tf_available, is_torch_available, logger, AutoTokenizer, PreTrainedTokenizer
if is_tf_available():
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering
@@ -154,6 +154,8 @@ class QuestionAnsweringPipeline(Pipeline):
return_attention_masks=True, return_input_lengths=False
)
+token_type_ids = inputs.pop('token_type_ids')
if is_tf_available():
# TODO trace model
start, end = self.model(inputs)
@@ -167,7 +169,7 @@ class QuestionAnsweringPipeline(Pipeline):
answers = []
for i in range(len(texts)):
-context_idx = inputs['token_type_ids'][i] == 1
+context_idx = token_type_ids[i] == 1
start_, end_ = start[i, context_idx], end[i, context_idx]
# Normalize logits and spans to retrieve the answer