From 0c88c856d592134ee5a9a636f9b73f40b91784b5 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 18 Dec 2019 18:18:16 +0100 Subject: [PATCH] Unnest QuestionAnsweringArgumentHandler --- transformers/pipelines.py | 62 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/transformers/pipelines.py b/transformers/pipelines.py index bcb4d9e054b..a10078b027c 100755 --- a/transformers/pipelines.py +++ b/transformers/pipelines.py @@ -333,6 +333,63 @@ class NerPipeline(Pipeline): return answers +class QuestionAnsweringArgumentHandler(ArgumentHandler): + """ + QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped + to internal SquadExample / SquadFeature structures. + + QuestionAnsweringArgumentHandler manages all the possible to create SquadExample from the command-line supplied + arguments. + """ + def __call__(self, *args, **kwargs): + # Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating + if args is not None and len(args) > 0: + if len(args) == 1: + kwargs['X'] = args[0] + else: + kwargs['X'] = list(args) + + # Generic compatibility with sklearn and Keras + # Batched data + if 'X' in kwargs or 'data' in kwargs: + data = kwargs['X'] if 'X' in kwargs else kwargs['data'] + + if not isinstance(data, list): + data = [data] + + for i, item in enumerate(data): + if isinstance(item, dict): + if any(k not in item for k in ['question', 'context']): + raise KeyError('You need to provide a dictionary with keys {question:..., context:...}') + data[i] = QuestionAnsweringPipeline.create_sample(**item) + + elif isinstance(item, SquadExample): + continue + else: + raise ValueError( + '{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)' + .format('X' if 'X' in kwargs else 'data') + ) + inputs = data + + # Tabular input + elif 'question' in kwargs and 'context' in kwargs: + if isinstance(kwargs['question'], str): + kwargs['question'] = [kwargs['question']] + + if isinstance(kwargs['context'], str): + kwargs['context'] = [kwargs['context']] + + inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])] + else: + raise ValueError('Unknown arguments {}'.format(kwargs)) + + if not isinstance(inputs, list): + inputs = [inputs] + + return inputs + + class QuestionAnsweringPipeline(Pipeline): """ Question Answering pipeline using ModelForQuestionAnswering head. @@ -403,8 +460,9 @@ class QuestionAnsweringPipeline(Pipeline): else: return SquadExample(None, question, context, None, None, None) - def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]): - super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler()) + def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], device: int = -1, **kwargs): + super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(), + device=device, **kwargs) def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict: """