Unnest QuestionAnsweringArgumentHandler

This commit is contained in:
Morgan Funtowicz 2019-12-18 18:18:16 +01:00
parent e347725d8c
commit 0c88c856d5

View File

@ -333,6 +333,63 @@ class NerPipeline(Pipeline):
return answers
class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped
    to internal SquadExample / SquadFeature structures.

    QuestionAnsweringArgumentHandler manages all the possible ways to create SquadExample objects from the
    command-line supplied arguments.
    """

    def __call__(self, *args, **kwargs):
        """
        Normalize every supported call signature into a list of SquadExample.

        Supported forms:
            - positional arguments: a single one is handled as ``X``; several are gathered into a list
              and handled as ``X`` (overriding any ``X`` keyword);
            - ``X=...`` or ``data=...`` (``X`` wins when both are given): a SquadExample, a dict holding
              at least the ``question`` and ``context`` keys, or a list mixing both;
            - ``question=...`` together with ``context=...``: each either a str or an iterable of str.

        Returns:
            A newly built list of samples. A list supplied by the caller is never modified.

        Raises:
            KeyError: a dict item lacks the ``question`` or ``context`` key.
            ValueError: an item is neither a dict nor a SquadExample, or no supported argument was given.
        """
        # Positional args are handled exactly like X / data, so forward them rather than duplicating the logic.
        if args:
            kwargs['X'] = args[0] if len(args) == 1 else list(args)

        # Generic compatibility with sklearn and Keras
        # Batched data
        if 'X' in kwargs or 'data' in kwargs:
            arg_name = 'X' if 'X' in kwargs else 'data'
            data = kwargs[arg_name]
            if not isinstance(data, list):
                data = [data]

            # Collect into a fresh list: writing the converted samples back into `data` would silently
            # rewrite the caller's own list (and leave it half-converted if a later item is rejected).
            inputs = []
            for item in data:
                if isinstance(item, dict):
                    if any(k not in item for k in ['question', 'context']):
                        raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
                    # The whole dict is forwarded, so any extra key must be accepted by create_sample.
                    inputs.append(QuestionAnsweringPipeline.create_sample(**item))
                elif isinstance(item, SquadExample):
                    # Already in the target representation.
                    inputs.append(item)
                else:
                    raise ValueError(
                        '{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
                        .format(arg_name)
                    )
            return inputs

        # Tabular input
        if 'question' in kwargs and 'context' in kwargs:
            questions = kwargs['question']
            contexts = kwargs['context']
            if isinstance(questions, str):
                questions = [questions]
            if isinstance(contexts, str):
                contexts = [contexts]

            # NOTE: zip stops at the shorter of the two, so surplus questions or contexts are ignored.
            return [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(questions, contexts)]

        raise ValueError('Unknown arguments {}'.format(kwargs))
class QuestionAnsweringPipeline(Pipeline):
"""
Question Answering pipeline using ModelForQuestionAnswering head.
@ -403,8 +460,9 @@ class QuestionAnsweringPipeline(Pipeline):
else:
return SquadExample(None, question, context, None, None, None)
# Builds the pipeline with an argument handler looked up as an attribute of QuestionAnsweringPipeline
# (QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler); no device or extra keyword arguments
# are accepted or forwarded to the parent constructor.
# NOTE(review): presumably this is the pre-change constructor shown by the diff; another __init__
# follows immediately and would replace this one if both sat in the same class body - confirm.
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], device: int = -1, **kwargs):
    """
    Set up the pipeline with a QuestionAnsweringArgumentHandler as its argument parser.

    ``model``, ``tokenizer``, ``device`` and any additional keyword arguments are passed unchanged to
    the parent constructor.
    """
    # NOTE(review): -1 presumably selects the CPU - confirm against the parent Pipeline constructor.
    handler = QuestionAnsweringArgumentHandler()
    super().__init__(
        model,
        tokenizer,
        args_parser=handler,
        device=device,
        **kwargs
    )
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
"""