mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Unnest QuestionAnsweringArgumentHandler
This commit is contained in:
parent
e347725d8c
commit
0c88c856d5
@ -333,6 +333,63 @@ class NerPipeline(Pipeline):
|
||||
return answers
|
||||
|
||||
|
||||
class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
"""
|
||||
QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
|
||||
to internal SquadExample / SquadFeature structures.
|
||||
|
||||
QuestionAnsweringArgumentHandler manages all the possible to create SquadExample from the command-line supplied
|
||||
arguments.
|
||||
"""
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
|
||||
if args is not None and len(args) > 0:
|
||||
if len(args) == 1:
|
||||
kwargs['X'] = args[0]
|
||||
else:
|
||||
kwargs['X'] = list(args)
|
||||
|
||||
# Generic compatibility with sklearn and Keras
|
||||
# Batched data
|
||||
if 'X' in kwargs or 'data' in kwargs:
|
||||
data = kwargs['X'] if 'X' in kwargs else kwargs['data']
|
||||
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
if any(k not in item for k in ['question', 'context']):
|
||||
raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
|
||||
data[i] = QuestionAnsweringPipeline.create_sample(**item)
|
||||
|
||||
elif isinstance(item, SquadExample):
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
|
||||
.format('X' if 'X' in kwargs else 'data')
|
||||
)
|
||||
inputs = data
|
||||
|
||||
# Tabular input
|
||||
elif 'question' in kwargs and 'context' in kwargs:
|
||||
if isinstance(kwargs['question'], str):
|
||||
kwargs['question'] = [kwargs['question']]
|
||||
|
||||
if isinstance(kwargs['context'], str):
|
||||
kwargs['context'] = [kwargs['context']]
|
||||
|
||||
inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])]
|
||||
else:
|
||||
raise ValueError('Unknown arguments {}'.format(kwargs))
|
||||
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class QuestionAnsweringPipeline(Pipeline):
|
||||
"""
|
||||
Question Answering pipeline using ModelForQuestionAnswering head.
|
||||
@ -403,8 +460,9 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
else:
|
||||
return SquadExample(None, question, context, None, None, None)
|
||||
|
||||
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
|
||||
super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())
|
||||
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], device: int = -1, **kwargs):
|
||||
super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(),
|
||||
device=device, **kwargs)
|
||||
|
||||
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user