mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Remove token_type_ids for compatibility with DistilBert
This commit is contained in:
parent
fe0f552e00
commit
a7d3794a29
@ -20,7 +20,7 @@ from typing import Union, Optional, Tuple, List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_tf_available, logger, AutoTokenizer, PreTrainedTokenizer, is_torch_available
|
||||
from transformers import is_tf_available, is_torch_available, logger, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering
|
||||
@ -154,6 +154,8 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
return_attention_masks=True, return_input_lengths=False
|
||||
)
|
||||
|
||||
token_type_ids = inputs.pop('token_type_ids')
|
||||
|
||||
if is_tf_available():
|
||||
# TODO trace model
|
||||
start, end = self.model(inputs)
|
||||
@ -167,7 +169,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
answers = []
|
||||
for i in range(len(texts)):
|
||||
context_idx = inputs['token_type_ids'][i] == 1
|
||||
context_idx = token_type_ids[i] == 1
|
||||
start_, end_ = start[i, context_idx], end[i, context_idx]
|
||||
|
||||
# Normalize logits and spans to retrieve the answer
|
||||
|
Loading…
Reference in New Issue
Block a user